From 8934b85c92379989311053dae2fcd9f8e2260ede Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 15 Jan 2017 14:58:29 +0100 Subject: [PATCH 001/159] client/posts: fix skipping duplicate uploads --- .../js/controllers/post_upload_controller.js | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/client/js/controllers/post_upload_controller.js b/client/js/controllers/post_upload_controller.js index cba0918b..45023413 100644 --- a/client/js/controllers/post_upload_controller.js +++ b/client/js/controllers/post_upload_controller.js @@ -95,16 +95,20 @@ class PostUploadController { return reverseSearchPromise.then(searchResult => { if (searchResult) { // notify about exact duplicate - if (searchResult.exactPost && !skipDuplicates) { - let error = new Error('Post already uploaded ' + - `(@${searchResult.exactPost.id})`); - error.uploadable = uploadable; - return Promise.reject(error); + if (searchResult.exactPost) { + if (skipDuplicates) { + this._view.removeUploadable(uploadable); + return Promise.resolve(); + } else { + let error = new Error('Post already uploaded ' + + `(@${searchResult.exactPost.id})`); + error.uploadable = uploadable; + return Promise.reject(error); + } } // notify about similar posts - if (!searchResult.exactPost && - searchResult.similarPosts.length) { + if (searchResult.similarPosts.length) { let error = new Error( `Found ${searchResult.similarPosts.length} similar ` + 'posts.\nYou can resume or discard this upload.'); From eead1560ee3f7749ffc8925e33a7bc81831c48a2 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 15 Jan 2017 21:09:08 +0100 Subject: [PATCH 002/159] client: fix reporting errors in pager --- client/js/views/manual_page_view.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/js/views/manual_page_view.js b/client/js/views/manual_page_view.js index 3507fd7b..94fc70b9 100644 --- a/client/js/views/manual_page_view.js +++ b/client/js/views/manual_page_view.js @@ -102,7 +102,7 @@ class ManualPageView { views.syncScrollPosition(); }, response => { - this.showError(response.description); + this.showError(response.message); }); } From 7414d1f7a683e3ced7a78ae8ed71318ab52ae9f8 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 20 Jan 2017 22:16:42 +0100 Subject: [PATCH 003/159] server/posts: fix getting posts around Querying this undocumented API resulted in 500 ISE unless the client asked only for the "id" field. --- server/szurubooru/search/configs/post_search_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index b0ea3300..cf02fecc 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -102,7 +102,7 @@ class PostSearchConfig(BaseSearchConfig): search_query.special_tokens = new_special_tokens def create_around_query(self): - return db.session.query(db.Post.post_id) + return db.session.query(db.Post).options(lazyload('*')) def create_filter_query(self, disable_eager_loads): strategy = lazyload if disable_eager_loads else subqueryload From b0e60a340bb983dab78aed14721d14e43534d54a Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 20 Jan 2017 22:38:56 +0100 Subject: [PATCH 004/159] client/home: centerize messages --- client/css/home-view.styl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/client/css/home-view.styl b/client/css/home-view.styl index e17e4bb5..06292aea 100644 --- a/client/css/home-view.styl +++ b/client/css/home-view.styl @@ -9,8 +9,10 @@ font-size: 30pt margin: 0 - .message - margin-bottom: 2em + .messages + text-align: center + .message + margin: 0 auto 2em auto form width: auto From 6714f05b49f314fe27e977b5555cb6c6cd6b1e17 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 20 Jan 2017 22:39:49 +0100 Subject: [PATCH 005/159] client/posts: remove bullets from post management --- client/css/post-main-view.styl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/client/css/post-main-view.styl b/client/css/post-main-view.styl index 689ec73c..b4d9b8fe 100644 --- a/client/css/post-main-view.styl +++ b/client/css/post-main-view.styl @@ -130,8 +130,13 @@ display: inline-block .management - li + ul + list-style-type: none margin: 0 + padding: 0 + li + margin: 0 + padding: 0 label margin-bottom: 0.3em From 1acceb941d84775cdeb0019bdfbdd63e9e125982 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 20 Jan 2017 21:51:04 +0100 Subject: [PATCH 006/159] client: refactor linking and routing Print all links through new uri.js component Refactor the router to use more predictable parsing Fix linking to entities with weird names (that contain slashes, + etc.) --- client/html/comment.tpl | 4 +- client/html/comments_page.tpl | 2 +- client/html/help.tpl | 10 +-- client/html/help_search.tpl | 8 +-- client/html/home.tpl | 2 +- client/html/home_footer.tpl | 2 +- client/html/login.tpl | 2 +- client/html/not_found.tpl | 2 +- client/html/post_detail.tpl | 4 +- client/html/post_readonly_sidebar.tpl | 4 +- client/html/posts_header.tpl | 2 +- client/html/posts_page.tpl | 2 +- client/html/settings.tpl | 2 +- client/html/tag.tpl | 8 +-- client/html/tag_category_row.tpl | 2 +- client/html/tag_delete.tpl | 2 +- client/html/tag_summary.tpl | 2 +- client/html/tags_header.tpl | 4 +- client/html/tags_page.tpl | 20 +++--- client/html/user.tpl | 6 +- client/html/user_registration.tpl | 2 +- client/html/user_summary.tpl | 10 +-- client/html/users_header.tpl | 2 +- client/html/users_page.tpl | 4 +- client/js/api.js | 1 + client/js/controllers/auth_controller.js | 9 +-- client/js/controllers/comments_controller.js | 7 +- client/js/controllers/help_controller.js | 6 +- client/js/controllers/home_controller.js | 2 +- client/js/controllers/not_found_controller.js | 2 +- .../controllers/password_reset_controller.js | 17 ++--- .../js/controllers/post_detail_controller.js | 10 ++- client/js/controllers/post_list_controller.js | 9 ++- client/js/controllers/post_main_controller.js | 15 ++-- .../js/controllers/post_upload_controller.js | 5 +- client/js/controllers/settings_controller.js | 2 +- client/js/controllers/snapshots_controller.js | 7 +- .../controllers/tag_categories_controller.js | 2 +- client/js/controllers/tag_controller.js | 17 +++-- client/js/controllers/tag_list_controller.js | 9 ++- client/js/controllers/user_controller.js | 14 ++-- client/js/controllers/user_list_controller.js | 9 ++- .../user_registration_controller.js | 5 +- client/js/controls/tag_input_control.js | 9 ++- client/js/main.js | 4 +- client/js/models/comment.js | 13 ++-- client/js/models/info.js | 3 +- client/js/models/post.js | 31 ++++++--- client/js/models/post_list.js | 33 +++++---- client/js/models/snapshot_list.js | 19 +++--- client/js/models/tag.js | 13 ++-- client/js/models/tag_category.js | 7 +- client/js/models/tag_category_list.js | 19 ++++-- client/js/models/tag_list.js | 26 ++++--- client/js/models/user.js | 11 ++- client/js/models/user_list.js | 19 +++--- client/js/router.js | 68 +++++++++++++------ client/js/util/misc.js | 38 +---------- client/js/util/uri.js | 62 +++++++++++++++++ client/js/util/views.js | 28 ++++---- client/js/views/home_view.js | 4 +- client/js/views/manual_page_view.js | 1 - client/js/views/post_main_view.js | 9 +-- client/js/views/users_header_view.js | 1 - client/package.json | 1 - 65 files changed, 380 insertions(+), 295 deletions(-) create mode 100644 client/js/util/uri.js diff --git a/client/html/comment.tpl b/client/html/comment.tpl index 04469c9a..6bea7045 100644 --- a/client/html/comment.tpl +++ b/client/html/comment.tpl @@ -1,7 +1,7 @@
<% if (ctx.user && ctx.user.name && ctx.canViewUsers) { %> - + '> <% } %> <%= ctx.makeThumbnail(ctx.user ? ctx.user.avatarUrl : null) %> @@ -23,7 +23,7 @@ diff --git a/client/html/help_search.tpl b/client/html/help_search.tpl index 8c22d73e..70737893 100644 --- a/client/html/help_search.tpl +++ b/client/html/help_search.tpl @@ -1,9 +1,9 @@ diff --git a/client/html/home.tpl b/client/html/home.tpl index f1d9b1c1..1e51b12b 100644 --- a/client/html/home.tpl +++ b/client/html/home.tpl @@ -8,7 +8,7 @@ <%= ctx.makeTextInput({name: 'search-text', placeholder: 'enter some tags'}) %> or - browse all posts + '>browse all posts <% } %> diff --git a/client/html/home_footer.tpl b/client/html/home_footer.tpl index 9ab9cd0c..8f9cb10a 100644 --- a/client/html/home_footer.tpl +++ b/client/html/home_footer.tpl @@ -2,6 +2,6 @@
  • <%- ctx.postCount %> posts
  • <%= ctx.makeFileSize(ctx.diskUsage) %>
  • Build <%- ctx.version %> from <%= ctx.makeRelativeTime(ctx.buildDate) %>
  • - <% if (ctx.canListSnapshots) { %>
  • History
  • + <% if (ctx.canListSnapshots) { %>
  • '>History
  • <% } %> diff --git a/client/html/login.tpl b/client/html/login.tpl index cc2a6805..8ccc439d 100644 --- a/client/html/login.tpl +++ b/client/html/login.tpl @@ -31,7 +31,7 @@
    <% if (ctx.canSendMails) { %> - Forgot the password? + '>Forgot the password? <% } %>
    diff --git a/client/html/not_found.tpl b/client/html/not_found.tpl index 15a7dfed..53bcdcd7 100644 --- a/client/html/not_found.tpl +++ b/client/html/not_found.tpl @@ -1,5 +1,5 @@

    Not found

    <%- ctx.path %> is not a valid URL.

    -

    Back to main page

    +

    Back to main page

    diff --git a/client/html/post_detail.tpl b/client/html/post_detail.tpl index 5e04ab22..65ff7868 100644 --- a/client/html/post_detail.tpl +++ b/client/html/post_detail.tpl @@ -2,9 +2,9 @@

    Post #<%- ctx.post.id %>

    diff --git a/client/html/post_readonly_sidebar.tpl b/client/html/post_readonly_sidebar.tpl index 32dafd6c..f29c8b65 100644 --- a/client/html/post_readonly_sidebar.tpl +++ b/client/html/post_readonly_sidebar.tpl @@ -67,14 +67,14 @@ --><% for (let tag of ctx.post.tags) { %>
  • <% if (ctx.canViewTags) { %>'>' class='<%= ctx.makeCssName(ctx.getTagCategory(tag), 'tag') %>'><% } %><% if (ctx.canViewTags) { %><% } %><% if (ctx.canListPosts) { %>'>' class='<%= ctx.makeCssName(ctx.getTagCategory(tag), 'tag') %>'><% } %><%- tag %> <% if (ctx.canListPosts) { %> diff --git a/client/html/tag_category_row.tpl b/client/html/tag_category_row.tpl index 3ec199c1..dcc8c16a 100644 --- a/client/html/tag_category_row.tpl +++ b/client/html/tag_category_row.tpl @@ -19,7 +19,7 @@ <% if (ctx.tagCategory.name) { %> - + '> <%- ctx.tagCategory.tagCount %> <% } else { %> diff --git a/client/html/tag_delete.tpl b/client/html/tag_delete.tpl index 7dbfc1e6..e7be8cf2 100644 --- a/client/html/tag_delete.tpl +++ b/client/html/tag_delete.tpl @@ -1,6 +1,6 @@
    -

    This tag has <%- ctx.tag.postCount %> usage(s).

    +

    This tag has '><%- ctx.tag.postCount %> usage(s).

    • diff --git a/client/html/tag_summary.tpl b/client/html/tag_summary.tpl index 6938e344..0513d643 100644 --- a/client/html/tag_summary.tpl +++ b/client/html/tag_summary.tpl @@ -36,6 +36,6 @@

      <%= ctx.makeMarkdown(ctx.tag.description || 'This tag has no description yet.') %> -

      This tag has <%- ctx.tag.postCount %> usage(s).

      +

      This tag has '><%- ctx.tag.postCount %> usage(s).

    diff --git a/client/html/tags_header.tpl b/client/html/tags_header.tpl index ed641279..59c5bcbc 100644 --- a/client/html/tags_header.tpl +++ b/client/html/tags_header.tpl @@ -8,9 +8,9 @@
    - Syntax help + '>Syntax help <% if (ctx.canEditTagCategories) { %> - Tag categories + '>Tag categories <% } %>
    diff --git a/client/html/tags_page.tpl b/client/html/tags_page.tpl index d66b342e..f27f49e3 100644 --- a/client/html/tags_page.tpl +++ b/client/html/tags_page.tpl @@ -4,37 +4,37 @@ <% if (ctx.query == 'sort:name' || !ctx.query) { %> - Tag name(s) + '>Tag name(s) <% } else { %> - Tag name(s) + '>Tag name(s) <% } %> <% if (ctx.query == 'sort:implication-count') { %> - Implications + '>Implications <% } else { %> - Implications + '>Implications <% } %> <% if (ctx.query == 'sort:suggestion-count') { %> - Suggestions + '>Suggestions <% } else { %> - Suggestions + '>Suggestions <% } %> <% if (ctx.query == 'sort:usages') { %> - Usages + '>Usages <% } else { %> - Usages + '>Usages <% } %> <% if (ctx.query == 'sort:creation-time') { %> - Created on + '>Created on <% } else { %> - Created on + '>Created on <% } %> diff --git a/client/html/user.tpl b/client/html/user.tpl index 4e7d00ae..28e34e67 100644 --- a/client/html/user.tpl +++ b/client/html/user.tpl @@ -2,12 +2,12 @@

    <%- ctx.user.name %>

    diff --git a/client/html/user_registration.tpl b/client/html/user_registration.tpl index e0e0f81c..a6d291f4 100644 --- a/client/html/user_registration.tpl +++ b/client/html/user_registration.tpl @@ -51,6 +51,6 @@
  • vote up/down on posts and comments

  • -

    By creating an account, you are agreeing to the Terms of Service.

    +

    By creating an account, you are agreeing to the '>Terms of Service.

    diff --git a/client/html/user_summary.tpl b/client/html/user_summary.tpl index b2d40ace..1a33a57f 100644 --- a/client/html/user_summary.tpl +++ b/client/html/user_summary.tpl @@ -10,9 +10,9 @@ @@ -20,8 +20,8 @@ <% } %> diff --git a/client/html/users_header.tpl b/client/html/users_header.tpl index 6cefe556..7faab8ea 100644 --- a/client/html/users_header.tpl +++ b/client/html/users_header.tpl @@ -8,7 +8,7 @@ diff --git a/client/html/users_page.tpl b/client/html/users_page.tpl index 61b07f84..6cb04d3a 100644 --- a/client/html/users_page.tpl +++ b/client/html/users_page.tpl @@ -4,7 +4,7 @@ -->
  • <% if (ctx.canViewUsers) { %> - + '> <% } %> <%= ctx.makeThumbnail(user.avatarUrl) %> <% if (ctx.canViewUsers) { %> @@ -12,7 +12,7 @@ <% } %>
    <% if (ctx.canViewUsers) { %> - + '> <% } %> <%- user.name %> <% if (ctx.canViewUsers) { %> diff --git a/client/js/api.js b/client/js/api.js index 0b0ed541..abfeb6f0 100644 --- a/client/js/api.js +++ b/client/js/api.js @@ -5,6 +5,7 @@ const request = require('superagent'); const config = require('./config.js'); const events = require('./events.js'); const progress = require('./util/progress.js'); +const uri = require('./util/uri.js'); let fileTokens = {}; diff --git a/client/js/controllers/auth_controller.js b/client/js/controllers/auth_controller.js index 42e6c0a2..2838e531 100644 --- a/client/js/controllers/auth_controller.js +++ b/client/js/controllers/auth_controller.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const api = require('../api.js'); +const uri = require('../util/uri.js'); const topNavigation = require('../models/top_navigation.js'); const LoginView = require('../views/login_view.js'); @@ -21,7 +22,7 @@ class LoginController { api.forget(); api.login(e.detail.name, e.detail.password, e.detail.remember) .then(() => { - const ctx = router.show('/'); + const ctx = router.show(uri.formatClientLink()); ctx.controller.showSuccess('Logged in'); }, error => { this._loginView.showError(error.message); @@ -34,16 +35,16 @@ class LogoutController { constructor() { api.forget(); api.logout(); - const ctx = router.show('/'); + const ctx = router.show(uri.formatClientLink()); ctx.controller.showSuccess('Logged out'); } } module.exports = router => { - router.enter('/login', (ctx, next) => { + router.enter(['login'], (ctx, next) => { ctx.controller = new LoginController(); }); - router.enter('/logout', (ctx, next) => { + router.enter(['logout'], (ctx, next) => { ctx.controller = new LogoutController(); }); }; diff --git a/client/js/controllers/comments_controller.js b/client/js/controllers/comments_controller.js index a7cb76ea..ce886323 100644 --- a/client/js/controllers/comments_controller.js +++ b/client/js/controllers/comments_controller.js @@ -1,7 +1,7 @@ 'use strict'; const api = require('../api.js'); -const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const PostList = require('../models/post_list.js'); const topNavigation = require('../models/top_navigation.js'); const PageController = require('../controllers/page_controller.js'); @@ -28,7 +28,7 @@ class CommentsController { getClientUrlForPage: page => { const parameters = Object.assign( {}, ctx.parameters, {page: page}); - return '/comments/' + misc.formatUrlParameters(parameters); + return uri.formatClientLink('comments', parameters); }, requestPage: page => { return PostList.search( @@ -69,7 +69,6 @@ class CommentsController { }; module.exports = router => { - router.enter('/comments/:parameters?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + router.enter(['comments'], (ctx, next) => { new CommentsController(ctx); }); }; diff --git a/client/js/controllers/help_controller.js b/client/js/controllers/help_controller.js index 2010176d..8e65346b 100644 --- a/client/js/controllers/help_controller.js +++ b/client/js/controllers/help_controller.js @@ -12,13 +12,13 @@ class HelpController { } module.exports = router => { - router.enter('/help', (ctx, next) => { + router.enter(['help'], (ctx, next) => { new HelpController(); }); - router.enter('/help/:section', (ctx, next) => { + router.enter(['help', ':section'], (ctx, next) => { new HelpController(ctx.parameters.section); }); - router.enter('/help/:section/:subsection', (ctx, next) => { + router.enter(['help', ':section', ':subsection'], (ctx, next) => { new HelpController(ctx.parameters.section, ctx.parameters.subsection); }); }; diff --git a/client/js/controllers/home_controller.js b/client/js/controllers/home_controller.js index 8619c1ad..b7590a00 100644 --- a/client/js/controllers/home_controller.js +++ b/client/js/controllers/home_controller.js @@ -44,7 +44,7 @@ class HomeController { }; module.exports = router => { - router.enter('/', (ctx, next) => { + router.enter([], (ctx, next) => { ctx.controller = new HomeController(); }); }; diff --git a/client/js/controllers/not_found_controller.js b/client/js/controllers/not_found_controller.js index 1d54b219..66f52e92 100644 --- a/client/js/controllers/not_found_controller.js +++ b/client/js/controllers/not_found_controller.js @@ -12,7 +12,7 @@ class NotFoundController { }; module.exports = router => { - router.enter('*', (ctx, next) => { + router.enter(null, (ctx, next) => { ctx.controller = new NotFoundController(ctx.canonicalPath); }); }; diff --git a/client/js/controllers/password_reset_controller.js b/client/js/controllers/password_reset_controller.js index 3bd77fd5..e0a9801f 100644 --- a/client/js/controllers/password_reset_controller.js +++ b/client/js/controllers/password_reset_controller.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const api = require('../api.js'); +const uri = require('../util/uri.js'); const topNavigation = require('../models/top_navigation.js'); const PasswordResetView = require('../views/password_reset_view.js'); @@ -20,7 +21,7 @@ class PasswordResetController { this._passwordResetView.disableForm(); api.forget(); api.logout(); - api.get('/password-reset/' + e.detail.userNameOrEmail) + api.get(uri.formatApiLink('password-reset', e.detail.userNameOrEmail)) .then(() => { this._passwordResetView.showSuccess( 'E-mail has been sent. To finish the procedure, ' + @@ -37,26 +38,26 @@ class PasswordResetFinishController { api.forget(); api.logout(); let password = null; - api.post('/password-reset/' + name, {token: token}) + api.post(uri.formatApiLink('password-reset', name), {token: token}) .then(response => { password = response.password; return api.login(name, password, false); }).then(() => { - const ctx = router.show('/'); + const ctx = router.show(uri.formatClientLink()); ctx.controller.showSuccess('New password: ' + password); }, error => { - const ctx = router.show('/'); + const ctx = router.show(uri.formatClientLink()); ctx.controller.showError(error.message); }); } } module.exports = router => { - router.enter('/password-reset', (ctx, next) => { + router.enter(['password-reset'], (ctx, next) => { ctx.controller = new PasswordResetController(); }); - router.enter(/\/password-reset\/([^:]+):([^:]+)$/, (ctx, next) => { - ctx.controller = new PasswordResetFinishController( - ctx.parameters[0], ctx.parameters[1]); + router.enter(['password-reset', ':descriptor'], (ctx, next) => { + const [name, token] = ctx.parameters.descriptor.split(':', 2); + ctx.controller = new PasswordResetFinishController(name, token); }); }; diff --git a/client/js/controllers/post_detail_controller.js b/client/js/controllers/post_detail_controller.js index 3930b5e2..47d5430c 100644 --- a/client/js/controllers/post_detail_controller.js +++ b/client/js/controllers/post_detail_controller.js @@ -3,6 +3,7 @@ const router = require('../router.js'); const api = require('../api.js'); const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const settings = require('../models/settings.js'); const Post = require('../models/post.js'); const PostList = require('../models/post_list.js'); @@ -55,7 +56,8 @@ class PostDetailController extends BasePostController { misc.disableExitConfirmation(); if (this._id !== e.detail.post.id) { router.replace( - '/post/' + e.detail.post.id + '/' + section, null, false); + uri.formatClientLink('post', e.detail.post.id, section), + null, false); } } @@ -67,7 +69,9 @@ class PostDetailController extends BasePostController { this._installView(e.detail.post, 'merge'); this._view.showSuccess('Post merged.'); router.replace( - '/post/' + e.detail.targetPost.id + '/merge', null, false); + uri.formatClientLink( + 'post', e.detail.targetPost.id, 'merge'), + null, false); }, error => { this._view.showError(error.message); this._view.enableForm(); @@ -77,7 +81,7 @@ class PostDetailController extends BasePostController { module.exports = router => { router.enter( - '/post/:id/merge', + ['post', ':id', 'merge'], (ctx, next) => { ctx.controller = new PostDetailController(ctx, 'merge'); }); diff --git a/client/js/controllers/post_list_controller.js b/client/js/controllers/post_list_controller.js index f7928b07..1e6dd7c8 100644 --- a/client/js/controllers/post_list_controller.js +++ b/client/js/controllers/post_list_controller.js @@ -2,7 +2,7 @@ const api = require('../api.js'); const settings = require('../models/settings.js'); -const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const PostList = require('../models/post_list.js'); const topNavigation = require('../models/top_navigation.js'); const PageController = require('../controllers/page_controller.js'); @@ -52,7 +52,7 @@ class PostListController { history.pushState( null, window.title, - '/posts/' + misc.formatUrlParameters(e.detail.parameters)); + uri.formatClientLink('posts', e.detail.parameters)); Object.assign(this._ctx.parameters, e.detail.parameters); this._syncPageController(); } @@ -89,7 +89,7 @@ class PostListController { this._pageController.run({ parameters: this._ctx.parameters, getClientUrlForPage: page => { - return '/posts/' + misc.formatUrlParameters( + return uri.formatClientLink('posts', Object.assign({}, this._ctx.parameters, {page: page})); }, requestPage: page => { @@ -114,7 +114,6 @@ class PostListController { module.exports = router => { router.enter( - '/posts/:parameters(.*)?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + ['posts'], (ctx, next) => { ctx.controller = new PostListController(ctx); }); }; diff --git a/client/js/controllers/post_main_controller.js b/client/js/controllers/post_main_controller.js index 920ce758..ba5f6886 100644 --- a/client/js/controllers/post_main_controller.js +++ b/client/js/controllers/post_main_controller.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const api = require('../api.js'); +const uri = require('../util/uri.js'); const misc = require('../util/misc.js'); const settings = require('../models/settings.js'); const Comment = require('../models/comment.js'); @@ -29,8 +30,8 @@ class PostMainController extends BasePostController { if (parameters.query) { ctx.state.parameters = parameters; const url = editMode ? - '/post/' + ctx.parameters.id + '/edit' : - '/post/' + ctx.parameters.id; + uri.formatClientLink('post', ctx.parameters.id, 'edit') : + uri.formatClientLink('post', ctx.parameters.id); router.replace(url, ctx.state, false); } @@ -124,7 +125,7 @@ class PostMainController extends BasePostController { } _evtMergePost(e) { - router.show('/post/' + e.detail.post.id + '/merge'); + router.show(uri.formatClientLink('post', e.detail.post.id, 'merge')); } _evtDeletePost(e) { @@ -133,7 +134,7 @@ class PostMainController extends BasePostController { e.detail.post.delete() .then(() => { misc.disableExitConfirmation(); - const ctx = router.show('/posts'); + const ctx = router.show(uri.formatClientLink('posts')); ctx.controller.showSuccess('Post deleted.'); }, error => { this._view.sidebarControl.showError(error.message); @@ -244,8 +245,7 @@ class PostMainController extends BasePostController { } module.exports = router => { - router.enter('/post/:id/edit/:parameters(.*)?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + router.enter(['post', ':id', 'edit'], (ctx, next) => { // restore parameters from history state if (ctx.state.parameters) { @@ -254,8 +254,7 @@ module.exports = router => { ctx.controller = new PostMainController(ctx, true); }); router.enter( - '/post/:id/:parameters(.*)?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + ['post', ':id'], (ctx, next) => { // restore parameters from history state if (ctx.state.parameters) { diff --git a/client/js/controllers/post_upload_controller.js b/client/js/controllers/post_upload_controller.js index 45023413..977ac2be 100644 --- a/client/js/controllers/post_upload_controller.js +++ b/client/js/controllers/post_upload_controller.js @@ -2,6 +2,7 @@ const api = require('../api.js'); const router = require('../router.js'); +const uri = require('../util/uri.js'); const misc = require('../util/misc.js'); const progress = require('../util/progress.js'); const topNavigation = require('../models/top_navigation.js'); @@ -61,7 +62,7 @@ class PostUploadController { .then(() => { this._view.clearMessages(); misc.disableExitConfirmation(); - const ctx = router.show('/posts'); + const ctx = router.show(uri.formatClientLink('posts')); ctx.controller.showSuccess('Posts uploaded.'); }, error => { if (error.uploadable) { @@ -149,7 +150,7 @@ class PostUploadController { } module.exports = router => { - router.enter('/upload', (ctx, next) => { + router.enter(['upload'], (ctx, next) => { ctx.controller = new PostUploadController(); }); }; diff --git a/client/js/controllers/settings_controller.js b/client/js/controllers/settings_controller.js index 2b087ebf..224b2059 100644 --- a/client/js/controllers/settings_controller.js +++ b/client/js/controllers/settings_controller.js @@ -22,7 +22,7 @@ class SettingsController { }; module.exports = router => { - router.enter('/settings', (ctx, next) => { + router.enter(['settings'], (ctx, next) => { ctx.controller = new SettingsController(); }); }; diff --git a/client/js/controllers/snapshots_controller.js b/client/js/controllers/snapshots_controller.js index 3aa3f89e..034bf498 100644 --- a/client/js/controllers/snapshots_controller.js +++ b/client/js/controllers/snapshots_controller.js @@ -1,7 +1,7 @@ 'use strict'; const api = require('../api.js'); -const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const SnapshotList = require('../models/snapshot_list.js'); const PageController = require('../controllers/page_controller.js'); const topNavigation = require('../models/top_navigation.js'); @@ -25,7 +25,7 @@ class SnapshotsController { getClientUrlForPage: page => { const parameters = Object.assign( {}, ctx.parameters, {page: page}); - return '/history/' + misc.formatUrlParameters(parameters); + return uri.formatClientLink('history', parameters); }, requestPage: page => { return SnapshotList.search('', page, 25); @@ -43,7 +43,6 @@ class SnapshotsController { } module.exports = router => { - router.enter('/history/:parameters?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + router.enter(['history'], (ctx, next) => { ctx.controller = new SnapshotsController(ctx); }); }; diff --git a/client/js/controllers/tag_categories_controller.js b/client/js/controllers/tag_categories_controller.js index 0dd7a77b..dadf8e4a 100644 --- a/client/js/controllers/tag_categories_controller.js +++ b/client/js/controllers/tag_categories_controller.js @@ -51,7 +51,7 @@ class TagCategoriesController { } module.exports = router => { - router.enter('/tag-categories', (ctx, next) => { + router.enter(['tag-categories'], (ctx, next) => { ctx.controller = new TagCategoriesController(ctx, next); }); }; diff --git a/client/js/controllers/tag_controller.js b/client/js/controllers/tag_controller.js index 939b8f04..c1de3a36 100644 --- a/client/js/controllers/tag_controller.js +++ b/client/js/controllers/tag_controller.js @@ -3,6 +3,7 @@ const router = require('../router.js'); const api = require('../api.js'); const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const tags = require('../tags.js'); const Tag = require('../models/tag.js'); const topNavigation = require('../models/top_navigation.js'); @@ -61,7 +62,8 @@ class TagController { misc.disableExitConfirmation(); if (this._name !== e.detail.tag.names[0]) { router.replace( - '/tag/' + e.detail.tag.names[0] + '/' + section, null, false); + uri.formatClientLink('tag', e.detail.tag.names[0], section), + null, false); } } @@ -99,7 +101,8 @@ class TagController { this._view.showSuccess('Tag merged.'); this._view.enableForm(); router.replace( - '/tag/' + e.detail.targetTagName + '/merge', null, false); + uri.formatClientLink('tag', e.detail.targetTagName, 'merge'), + null, false); }, error => { this._view.showError(error.message); this._view.enableForm(); @@ -111,7 +114,7 @@ class TagController { this._view.disableForm(); e.detail.tag.delete() .then(() => { - const ctx = router.show('/tags/'); + const ctx = router.show(uri.formatClientLink('tags')); ctx.controller.showSuccess('Tag deleted.'); }, error => { this._view.showError(error.message); @@ -121,16 +124,16 @@ class TagController { } module.exports = router => { - router.enter('/tag/:name(.+?)/edit', (ctx, next) => { + router.enter(['tag', ':name', 'edit'], (ctx, next) => { ctx.controller = new TagController(ctx, 'edit'); }); - router.enter('/tag/:name(.+?)/merge', (ctx, next) => { + router.enter(['tag', ':name', 'merge'], (ctx, next) => { ctx.controller = new TagController(ctx, 'merge'); }); - router.enter('/tag/:name(.+?)/delete', (ctx, next) => { + router.enter(['tag', ':name', 'delete'], (ctx, next) => { ctx.controller = new TagController(ctx, 'delete'); }); - router.enter('/tag/:name(.+)', (ctx, next) => { + router.enter(['tag', ':name'], (ctx, next) => { ctx.controller = new TagController(ctx, 'summary'); }); }; diff --git a/client/js/controllers/tag_list_controller.js b/client/js/controllers/tag_list_controller.js index d61d8f71..2b485d2e 100644 --- a/client/js/controllers/tag_list_controller.js +++ b/client/js/controllers/tag_list_controller.js @@ -1,7 +1,7 @@ 'use strict'; const api = require('../api.js'); -const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const TagList = require('../models/tag_list.js'); const topNavigation = require('../models/top_navigation.js'); const PageController = require('../controllers/page_controller.js'); @@ -49,7 +49,7 @@ class TagListController { history.pushState( null, window.title, - '/tags/' + misc.formatUrlParameters(e.detail.parameters)); + uri.formatClientLink('tags', e.detail.parameters)); Object.assign(this._ctx.parameters, e.detail.parameters); this._syncPageController(); } @@ -60,7 +60,7 @@ class TagListController { getClientUrlForPage: page => { const parameters = Object.assign( {}, this._ctx.parameters, {page: page}); - return '/tags/' + misc.formatUrlParameters(parameters); + return uri.formatClientLink('tags', parameters); }, requestPage: page => { return TagList.search( @@ -75,7 +75,6 @@ class TagListController { module.exports = router => { router.enter( - '/tags/:parameters(.*)?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + ['tags'], (ctx, next) => { ctx.controller = new TagListController(ctx); }); }; diff --git a/client/js/controllers/user_controller.js b/client/js/controllers/user_controller.js index e9c1fb39..46020f37 100644 --- a/client/js/controllers/user_controller.js +++ b/client/js/controllers/user_controller.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const api = require('../api.js'); +const uri = require('../util/uri.js'); const misc = require('../util/misc.js'); const config = require('../config.js'); const views = require('../util/views.js'); @@ -77,7 +78,8 @@ class UserController { misc.disableExitConfirmation(); if (this._name !== e.detail.user.name) { router.replace( - '/user/' + e.detail.user.name + '/' + section, null, false); + uri.formatClientLink('user', e.detail.user.name, section), + null, false); } } @@ -135,10 +137,10 @@ class UserController { api.logout(); } if (api.hasPrivilege('users:list')) { - const ctx = router.show('/users'); + const ctx = router.show(uri.formatClientLink('users')); ctx.controller.showSuccess('Account deleted.'); } else { - const ctx = router.show('/'); + const ctx = router.show(uri.formatClientLink()); ctx.controller.showSuccess('Account deleted.'); } }, error => { @@ -149,13 +151,13 @@ class UserController { } module.exports = router => { - router.enter('/user/:name', (ctx, next) => { + router.enter(['user', ':name'], (ctx, next) => { ctx.controller = new UserController(ctx, 'summary'); }); - router.enter('/user/:name/edit', (ctx, next) => { + router.enter(['user', ':name', 'edit'], (ctx, next) => { ctx.controller = new UserController(ctx, 'edit'); }); - router.enter('/user/:name/delete', (ctx, next) => { + router.enter(['user', ':name', 'delete'], (ctx, next) => { ctx.controller = new UserController(ctx, 'delete'); }); }; diff --git a/client/js/controllers/user_list_controller.js b/client/js/controllers/user_list_controller.js index 98e70031..13aec009 100644 --- a/client/js/controllers/user_list_controller.js +++ b/client/js/controllers/user_list_controller.js @@ -1,7 +1,7 @@ 'use strict'; const api = require('../api.js'); -const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const UserList = require('../models/user_list.js'); const topNavigation = require('../models/top_navigation.js'); const PageController = require('../controllers/page_controller.js'); @@ -41,7 +41,7 @@ class UserListController { history.pushState( null, window.title, - '/users/' + misc.formatUrlParameters(e.detail.parameters)); + uri.formatClientLink('users', e.detail.parameters)); Object.assign(this._ctx.parameters, e.detail.parameters); this._syncPageController(); } @@ -52,7 +52,7 @@ class UserListController { getClientUrlForPage: page => { const parameters = Object.assign( {}, this._ctx.parameters, {page: page}); - return '/users/' + misc.formatUrlParameters(parameters); + return uri.formatClientLink('users', parameters); }, requestPage: page => { return UserList.search(this._ctx.parameters.query, page); @@ -69,7 +69,6 @@ class UserListController { module.exports = router => { router.enter( - '/users/:parameters(.*)?', - (ctx, next) => { misc.parseUrlParametersRoute(ctx, next); }, + ['users'], (ctx, next) => { ctx.controller = new UserListController(ctx); }); }; diff --git a/client/js/controllers/user_registration_controller.js b/client/js/controllers/user_registration_controller.js index 47165d60..7d822380 100644 --- a/client/js/controllers/user_registration_controller.js +++ b/client/js/controllers/user_registration_controller.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const api = require('../api.js'); +const uri = require('../util/uri.js'); const User = require('../models/user.js'); const topNavigation = require('../models/top_navigation.js'); const RegistrationView = require('../views/registration_view.js'); @@ -32,7 +33,7 @@ class UserRegistrationController { api.forget(); return api.login(e.detail.name, e.detail.password, false); }).then(() => { - const ctx = router.show('/'); + const ctx = router.show(uri.formatClientLink()); ctx.controller.showSuccess('Welcome aboard!'); }, error => { this._view.showError(error.message); @@ -42,7 +43,7 @@ class UserRegistrationController { } module.exports = router => { - router.enter('/register', (ctx, next) => { + router.enter(['register'], (ctx, next) => { new UserRegistrationController(); }); }; diff --git a/client/js/controls/tag_input_control.js b/client/js/controls/tag_input_control.js index 2912b0fc..ee864281 100644 --- a/client/js/controls/tag_input_control.js +++ b/client/js/controls/tag_input_control.js @@ -3,6 +3,7 @@ const api = require('../api.js'); const tags = require('../tags.js'); const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const settings = require('../models/settings.js'); const events = require('../events.js'); const views = require('../util/views.js'); @@ -308,7 +309,7 @@ class TagInputControl extends events.EventTarget { tagLinkNode.classList.add(className); } tagLinkNode.setAttribute( - 'href', '/tag/' + encodeURIComponent(tagName)); + 'href', uri.formatClientLink('tag', tagName)); const tagIconNode = document.createElement('i'); tagIconNode.classList.add('fa'); tagIconNode.classList.add('fa-tag'); @@ -319,7 +320,7 @@ class TagInputControl extends events.EventTarget { searchLinkNode.classList.add(className); } searchLinkNode.setAttribute( - 'href', '/posts/query=' + encodeURIComponent(tagName)); + 'href', uri.formatClientLink('posts', {query: tagName})); searchLinkNode.textContent = tagName + ' '; searchLinkNode.addEventListener('click', e => { e.preventDefault(); @@ -360,7 +361,9 @@ class TagInputControl extends events.EventTarget { if (!browsingSettings.tagSuggestions) { return; } - api.get('/tag-siblings/' + tag.names[0], {noProgress: true}) + api.get( + uri.formatApiLink('tag-siblings', tag.names[0]), + {noProgress: true}) .then(response => { return Promise.resolve(response.results); }, response => { diff --git a/client/js/main.js b/client/js/main.js index 27c14e88..2308ee90 100644 --- a/client/js/main.js +++ b/client/js/main.js @@ -8,7 +8,7 @@ const router = require('./router.js'); history.scrollRestoration = 'manual'; router.exit( - /.*/, + null, (ctx, next) => { ctx.state.scrollX = window.scrollX; ctx.state.scrollY = window.scrollY; @@ -20,7 +20,7 @@ router.exit( const mousetrap = require('mousetrap'); router.enter( - /.*/, + null, (ctx, next) => { mousetrap.reset(); next(); diff --git a/client/js/models/comment.js b/client/js/models/comment.js index 623fd7e3..e10e83cb 100644 --- a/client/js/models/comment.js +++ b/client/js/models/comment.js @@ -1,6 +1,7 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const events = require('../events.js'); class Comment extends events.EventTarget { @@ -38,9 +39,9 @@ class Comment extends events.EventTarget { text: this._text, }; let promise = this._id ? - api.put('/comment/' + this._id, detail) : - api.post( - '/comments', Object.assign({postId: this._postId}, detail)); + api.put(uri.formatApiLink('comment', this.id), detail) : + api.post(uri.formatApiLink('comments'), + Object.assign({postId: this._postId}, detail)); return promise.then(response => { this._updateFromResponse(response); @@ -55,7 +56,7 @@ class Comment extends events.EventTarget { delete() { return api.delete( - '/comment/' + this._id, + uri.formatApiLink('comment', this.id), {version: this._version}) .then(response => { this.dispatchEvent(new CustomEvent('delete', { @@ -68,7 +69,9 @@ class Comment extends events.EventTarget { } setScore(score) { - return api.put('/comment/' + this._id + '/score', {score: score}) + return api.put( + uri.formatApiLink('comment', this.id, 'score'), + {score: score}) .then(response => { this._updateFromResponse(response); this.dispatchEvent(new CustomEvent('changeScore', { diff --git a/client/js/models/info.js b/client/js/models/info.js index 3b196dd4..35ba867a 100644 --- a/client/js/models/info.js +++ b/client/js/models/info.js @@ -1,11 +1,12 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const Post = require('./post.js'); class Info { static get() { - return api.get('/info') + return api.get(uri.formatApiLink('info')) .then(response => { return Promise.resolve(Object.assign( {}, diff --git a/client/js/models/post.js b/client/js/models/post.js index 8b767a2e..71f3eb98 100644 --- a/client/js/models/post.js +++ b/client/js/models/post.js @@ -1,6 +1,7 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const tags = require('../tags.js'); const events = require('../events.js'); const NoteList = require('./note_list.js'); @@ -69,7 +70,9 @@ class Post extends events.EventTarget { static reverseSearch(content) { let apiPromise = api.post( - '/posts/reverse-search', {}, {content: content}); + uri.formatApiLink('posts', 'reverse-search'), + {}, + {content: content}); let returnedPromise = apiPromise .then(response => { if (response.exactPost) { @@ -85,7 +88,7 @@ class Post extends events.EventTarget { } static get(id) { - return api.get('/post/' + id) + return api.get(uri.formatApiLink('post', id)) .then(response => { return Promise.resolve(Post.fromResponse(response)); }); @@ -149,8 +152,8 @@ class Post extends events.EventTarget { } let apiPromise = this._id ? - api.put('/post/' + this._id, detail, files) : - api.post('/posts', detail, files); + api.put(uri.formatApiLink('post', this.id), detail, files) : + api.post(uri.formatApiLink('posts'), detail, files); return apiPromise.then(response => { this._updateFromResponse(response); @@ -176,14 +179,18 @@ class Post extends events.EventTarget { } feature() { - return api.post('/featured-post', {id: this._id}) + return api.post( + uri.formatApiLink('featured-post'), + {id: this._id}) .then(response => { return Promise.resolve(); }); } delete() { - return api.delete('/post/' + this._id, {version: this._version}) + return api.delete( + uri.formatApiLink('post', this.id), + {version: this._version}) .then(response => { this.dispatchEvent(new CustomEvent('delete', { detail: { @@ -195,9 +202,9 @@ class Post extends events.EventTarget { } merge(targetId, useOldContent) { - return api.get('/post/' + encodeURIComponent(targetId)) + return api.get(uri.formatApiLink('post', targetId)) .then(response => { - return api.post('/post-merge/', { + return api.post(uri.formatApiLink('post-merge'), { removeVersion: this._version, remove: this._id, mergeToVersion: response.version, @@ -216,7 +223,9 @@ class Post extends events.EventTarget { } setScore(score) { - return api.put('/post/' + this._id + '/score', {score: score}) + return api.put( + uri.formatApiLink('post', this.id, 'score'), + {score: score}) .then(response => { const prevFavorite = this._ownFavorite; this._updateFromResponse(response); @@ -237,7 +246,7 @@ class Post extends events.EventTarget { } addToFavorites() { - return api.post('/post/' + this.id + '/favorite') + return api.post(uri.formatApiLink('post', this.id, 'favorite')) .then(response => { const prevScore = this._ownScore; this._updateFromResponse(response); @@ -258,7 +267,7 @@ class Post extends events.EventTarget { } removeFromFavorites() { - return api.delete('/post/' + this.id + '/favorite') + return api.delete(uri.formatApiLink('post', this.id, 'favorite')) .then(response => { const prevScore = this._ownScore; this._updateFromResponse(response); diff --git a/client/js/models/post_list.js b/client/js/models/post_list.js index 312afc86..2af40f58 100644 --- a/client/js/models/post_list.js +++ b/client/js/models/post_list.js @@ -1,29 +1,32 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const AbstractList = require('./abstract_list.js'); const Post = require('./post.js'); class PostList extends AbstractList { static getAround(id, searchQuery) { - const url = - `/post/${id}/around?fields=id` + - `&query=${encodeURIComponent(searchQuery)}`; - return api.get(url); + return api.get( + uri.formatApiLink( + 'post', id, 'around', {query: searchQuery, fields: 'id'})); } static search(text, page, pageSize, fields) { - const url = - `/posts/?query=${encodeURIComponent(text)}` + - `&page=${page}` + - `&pageSize=${pageSize}` + - `&fields=${fields.join(',')}`; - return api.get(url).then(response => { - return Promise.resolve(Object.assign( - {}, - response, - {results: PostList.fromResponse(response.results)})); - }); + return api.get( + uri.formatApiLink( + 'posts', { + query: text, + page: page, + pageSize: pageSize, + fields: fields.join(','), + })) + .then(response => { + return Promise.resolve(Object.assign( + {}, + response, + {results: PostList.fromResponse(response.results)})); + }); } } diff --git a/client/js/models/snapshot_list.js b/client/js/models/snapshot_list.js index 22348f4a..a23850a3 100644 --- a/client/js/models/snapshot_list.js +++ b/client/js/models/snapshot_list.js @@ -1,21 +1,20 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const AbstractList = require('./abstract_list.js'); const Snapshot = require('./snapshot.js'); class SnapshotList extends AbstractList { static search(text, page, pageSize) { - const url = - `/snapshots/?query=${encodeURIComponent(text)}` + - `&page=${page}` + - `&pageSize=${pageSize}`; - return api.get(url).then(response => { - return Promise.resolve(Object.assign( - {}, - response, - {results: SnapshotList.fromResponse(response.results)})); - }); + return api.get(uri.formatApiLink( + 'snapshots', {query: text, page: page, pageSize: pageSize})) + .then(response => { + return Promise.resolve(Object.assign( + {}, + response, + {results: SnapshotList.fromResponse(response.results)})); + }); } } diff --git a/client/js/models/tag.js b/client/js/models/tag.js index 07bbf7ff..1a156936 100644 --- a/client/js/models/tag.js +++ b/client/js/models/tag.js @@ -1,6 +1,7 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const events = require('../events.js'); const misc = require('../util/misc.js'); @@ -33,7 +34,7 @@ class Tag extends events.EventTarget { } static get(name) { - return api.get('/tag/' + encodeURIComponent(name)) + return api.get(uri.formatApiLink('tag', name)) .then(response => { return Promise.resolve(Tag.fromResponse(response)); }); @@ -60,8 +61,8 @@ class Tag extends events.EventTarget { } let promise = this._origName ? - api.put('/tag/' + encodeURIComponent(this._origName), detail) : - api.post('/tags', detail); + api.put(uri.formatApiLink('tag', this._origName), detail) : + api.post(uri.formatApiLink('tags'), detail); return promise .then(response => { this._updateFromResponse(response); @@ -75,9 +76,9 @@ class Tag extends events.EventTarget { } merge(targetName) { - return api.get('/tag/' + encodeURIComponent(targetName)) + return api.get(uri.formatApiLink('tag', targetName)) .then(response => { - return api.post('/tag-merge/', { + return api.post(uri.formatApiLink('tag-merge'), { removeVersion: this._version, remove: this._origName, mergeToVersion: response.version, @@ -96,7 +97,7 @@ class Tag extends events.EventTarget { delete() { return api.delete( - '/tag/' + encodeURIComponent(this._origName), + uri.formatApiLink('tag', this._origName), {version: this._version}) .then(response => { this.dispatchEvent(new CustomEvent('delete', { diff --git a/client/js/models/tag_category.js b/client/js/models/tag_category.js index cc08d345..04bd8fe6 100644 --- a/client/js/models/tag_category.js +++ b/client/js/models/tag_category.js @@ -1,6 +1,7 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const events = require('../events.js'); class TagCategory extends events.EventTarget { @@ -45,9 +46,9 @@ class TagCategory extends events.EventTarget { let promise = this._origName ? api.put( - '/tag-category/' + encodeURIComponent(this._origName), + uri.formatApiLink('tag-category', this._origName), detail) : - api.post('/tag-categories', detail); + api.post(uri.formatApiLink('tag-categories'), detail); return promise .then(response => { @@ -63,7 +64,7 @@ class TagCategory extends events.EventTarget { delete() { return api.delete( - '/tag-category/' + encodeURIComponent(this._origName), + uri.formatApiLink('tag-category', this._origName), {version: this._version}) .then(response => { this.dispatchEvent(new CustomEvent('delete', { diff --git a/client/js/models/tag_category_list.js b/client/js/models/tag_category_list.js index 812d6959..6c1182fb 100644 --- a/client/js/models/tag_category_list.js +++ b/client/js/models/tag_category_list.js @@ -1,6 +1,7 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const AbstractList = require('./abstract_list.js'); const TagCategory = require('./tag_category.js'); @@ -26,12 +27,13 @@ class TagCategoryList extends AbstractList { } static get() { - return api.get('/tag-categories/').then(response => { - return Promise.resolve(Object.assign( - {}, - response, - {results: TagCategoryList.fromResponse(response.results)})); - }); + return api.get(uri.formatApiLink('tag-categories')) + .then(response => { + return Promise.resolve(Object.assign( + {}, + response, + {results: TagCategoryList.fromResponse(response.results)})); + }); } get defaultCategory() { @@ -54,7 +56,10 @@ class TagCategoryList extends AbstractList { if (this._defaultCategory !== this._origDefaultCategory) { promises.push( api.put( - `/tag-category/${this._defaultCategory.name}/default`)); + uri.formatApiLink( + 'tag-category', + this._defaultCategory.name, + 'default'))); } return Promise.all(promises) diff --git a/client/js/models/tag_list.js b/client/js/models/tag_list.js index 268bcd80..d480745c 100644 --- a/client/js/models/tag_list.js +++ b/client/js/models/tag_list.js @@ -1,22 +1,26 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const AbstractList = require('./abstract_list.js'); const Tag = require('./tag.js'); class TagList extends AbstractList { static search(text, page, pageSize, fields) { - const url = - `/tags/?query=${encodeURIComponent(text)}` + - `&page=${page}` + - `&pageSize=${pageSize}` + - `&fields=${fields.join(',')}`; - return api.get(url).then(response => { - return Promise.resolve(Object.assign( - {}, - response, - {results: TagList.fromResponse(response.results)})); - }); + return api.get( + uri.formatApiLink( + 'tags', { + query: text, + page: page, + pageSize: pageSize, + fields: fields.join(','), + })) + .then(response => { + return Promise.resolve(Object.assign( + {}, + response, + {results: TagList.fromResponse(response.results)})); + }); } } diff --git a/client/js/models/user.js b/client/js/models/user.js index ab19bef3..abb5f57f 100644 --- a/client/js/models/user.js +++ b/client/js/models/user.js @@ -1,6 +1,7 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const events = require('../events.js'); class User extends events.EventTarget { @@ -40,7 +41,7 @@ class User extends events.EventTarget { } static get(name) { - return api.get('/user/' + encodeURIComponent(name)) + return api.get(uri.formatApiLink('user', name)) .then(response => { return Promise.resolve(User.fromResponse(response)); }); @@ -73,10 +74,8 @@ class User extends events.EventTarget { let promise = this._orig._name ? api.put( - '/user/' + encodeURIComponent(this._orig._name), - detail, - files) : - api.post('/users', detail, files); + uri.formatApiLink('user', this._orig._name), detail, files) : + api.post(uri.formatApiLink('users'), detail, files); return promise .then(response => { @@ -92,7 +91,7 @@ class User extends events.EventTarget { delete() { return api.delete( - '/user/' + encodeURIComponent(this._orig._name), + uri.formatApiLink('user', this._orig._name), {version: this._version}) .then(response => { this.dispatchEvent(new CustomEvent('delete', { diff --git a/client/js/models/user_list.js b/client/js/models/user_list.js index 4ca430d5..ff3e27cd 100644 --- a/client/js/models/user_list.js +++ b/client/js/models/user_list.js @@ -1,20 +1,21 @@ 'use strict'; const api = require('../api.js'); +const uri = require('../util/uri.js'); const AbstractList = require('./abstract_list.js'); const User = require('./user.js'); class UserList extends AbstractList { static search(text, page) { - const url = - `/users/?query=${encodeURIComponent(text)}` + - `&page=${page}&pageSize=30`; - return api.get(url).then(response => { - return Promise.resolve(Object.assign( - {}, - response, - {results: UserList.fromResponse(response.results)})); - }); + return api.get( + uri.formatApiLink( + 'users', {query: text, page: page, pageSize: 30})) + .then(response => { + return Promise.resolve(Object.assign( + {}, + response, + {results: UserList.fromResponse(response.results)})); + }); } } diff --git a/client/js/router.js b/client/js/router.js index 570fb3ef..9bff064f 100644 --- a/client/js/router.js +++ b/client/js/router.js @@ -1,6 +1,7 @@ 'use strict'; // modified page.js by visionmedia +// - changed regexes to components // - removed unused crap // - refactored to classes // - simplified method chains @@ -9,19 +10,12 @@ // - rename .save() to .replaceState() // - offer .url -const pathToRegexp = require('path-to-regexp'); const clickEvent = document.ontouchstart ? 'touchstart' : 'click'; +const uri = require('./util/uri.js'); let location = window.history.location || window.location; const base = ''; -function _decodeURLEncodedURIComponent(val) { - if (typeof val !== 'string') { - return val; - } - return decodeURIComponent(val.replace(/\+/g, ' ')); -} - function _isSameOrigin(href) { let origin = location.protocol + '//' + location.hostname; if (location.port) { @@ -55,11 +49,28 @@ class Context { }; class Route { - constructor(path, options) { - options = options || {}; - this.path = (path === '*') ? '(.*)' : path; + constructor(path) { this.method = 'GET'; - this.regexp = pathToRegexp(this.path, this.keys = [], options); + this.path = path; + + this.parameterNames = []; + if (this.path === null) { + this.regex = /.*/; + } else { + let parts = []; + for (let component of this.path) { + if (component[0] === ':') { + parts.push('([^/]+)'); + this.parameterNames.push(component.substr(1)); + } else { // assert [a-z]+ + parts.push(component); + } + } + let regexString = '^/' + parts.join('/'); + regexString += '(?:/*|/((?:(?:[a-z]+=[^/]+);)*(?:[a-z]+=[^/]+)))$'; + this.parameterNames.push('variable'); + this.regex = new RegExp(regexString); + } } middleware(fn) { @@ -72,24 +83,39 @@ class Route { } match(path, parameters) { - const keys = this.keys; const qsIndex = path.indexOf('?'); const pathname = ~qsIndex ? path.slice(0, qsIndex) : path; - const m = this.regexp.exec(pathname); + const match = this.regex.exec(pathname); - if (!m) { + if (!match) { return false; } - for (let i = 1, len = m.length; i < len; ++i) { - const key = keys[i - 1]; - const val = _decodeURLEncodedURIComponent(m[i]); - if (val !== undefined || - !(hasOwnProperty.call(parameters, key.name))) { - parameters[key.name] = val; + try { + for (let i = 1; i < match.length; i++) { + const name = this.parameterNames[i - 1]; + const value = match[i]; + if (value === undefined) { + continue; + } + + if (name === 'variable') { + for (let word of (value || '').split(/;/)) { + const [key, subvalue] = word.split(/=/, 2); + parameters[key] = uri.unescapeParam(subvalue); + } + } else { + parameters[name] = uri.unescapeParam(value); + } } + } catch (e) { + return false; } + // XXX: it is very unfitting place for this + parameters.query = parameters.query || ''; + parameters.page = parseInt(parameters.page || '1'); + return true; } }; diff --git a/client/js/util/misc.js b/client/js/util/misc.js index 3037b103..d825d996 100644 --- a/client/js/util/misc.js +++ b/client/js/util/misc.js @@ -1,6 +1,7 @@ 'use strict'; const markdown = require('./markdown.js'); +const uri = require('./uri.js'); function decamelize(str, sep) { sep = sep === undefined ? '-' : sep; @@ -99,44 +100,10 @@ function formatInlineMarkdown(text) { return markdown.formatInlineMarkdown(text); } -function formatUrlParameters(dict) { - let result = []; - for (let key of Object.keys(dict)) { - const value = dict[key]; - if (key === 'parameters') { - continue; - } - if (value) { - result.push(`${key}=${encodeURIComponent(value)}`); - } - } - return result.join(';'); -} - function splitByWhitespace(str) { return str.split(/\s+/).filter(s => s); } -function parseUrlParameters(query) { - let result = {}; - for (let word of (query || '').split(/;/)) { - const [key, value] = word.split(/=/, 2); - result[key] = value; - } - result.query = result.query || ''; - result.page = parseInt(result.page || '1'); - return result; -} - -function parseUrlParametersRoute(ctx, next) { - // ctx.parameters = {"user":...,"action":...} from /users/:user/:action - // ctx.parameters.parameters = value of :parameters as per /url/:parameters - Object.assign( - ctx.parameters, - parseUrlParameters(ctx.parameters.parameters)); - next(); -} - function unindent(callSite, ...args) { function format(str) { let size = -1; @@ -232,9 +199,6 @@ function dataURItoBlob(dataURI) { module.exports = { range: range, - formatUrlParameters: formatUrlParameters, - parseUrlParameters: parseUrlParameters, - parseUrlParametersRoute: parseUrlParametersRoute, formatRelativeTime: formatRelativeTime, formatFileSize: formatFileSize, formatMarkdown: formatMarkdown, diff --git a/client/js/util/uri.js b/client/js/util/uri.js new file mode 100644 index 00000000..52e90a06 --- /dev/null +++ b/client/js/util/uri.js @@ -0,0 +1,62 @@ +'use strict'; + +function formatApiLink(...values) { + let parts = []; + for (let value of values) { + if (value.constructor === Object) { + // assert this is the last piece + let variableParts = []; + for (let key of Object.keys(value)) { + if (value[key]) { + variableParts.push( + key + '=' + encodeURIComponent(value[key].toString())); + } + } + if (variableParts.length) { + parts.push('?' + variableParts.join('&')); + } + break; + } else { + parts.push(encodeURIComponent(value.toString())); + } + } + return '/' + parts.join('/'); +} + +function escapeParam(text) { + return encodeURIComponent(text).replace(/%/g, '$'); +} + +function unescapeParam(text) { + return decodeURIComponent(text.replace(/\$/g, '%')); +} + +function formatClientLink(...values) { + let parts = []; + for (let value of values) { + if (value.constructor === Object) { + // assert this is the last piece + let variableParts = []; + for (let key of Object.keys(value)) { + if (value[key]) { + variableParts.push( + key + '=' + escapeParam(value[key].toString())); + } + } + if (variableParts.length) { + parts.push(variableParts.join(';')); + } + break; + } else { + parts.push(escapeParam(value.toString())); + } + } + return '/' + parts.join('/'); +} + +module.exports = { + formatClientLink: formatClientLink, + formatApiLink: formatApiLink, + escapeParam: escapeParam, + unescapeParam: unescapeParam, +}; diff --git a/client/js/util/views.js b/client/js/util/views.js index fbf00a3b..7f062599 100644 --- a/client/js/util/views.js +++ b/client/js/util/views.js @@ -6,6 +6,7 @@ const templates = require('../templates.js'); const tags = require('../tags.js'); const domParser = new DOMParser(); const misc = require('./misc.js'); +const uri = require('./uri.js'); function _imbueId(options) { if (!options.id) { @@ -152,19 +153,15 @@ function makeNumericInput(options) { } function getPostUrl(id, parameters) { - let url = '/post/' + encodeURIComponent(id); - if (parameters && parameters.query) { - url += '/query=' + encodeURIComponent(parameters.query); - } - return url; + return uri.formatClientLink( + 'post', id, + parameters ? {query: parameters.query} : {}); } function getPostEditUrl(id, parameters) { - let url = '/post/' + encodeURIComponent(id) + '/edit'; - if (parameters && parameters.query) { - url += '/query=' + encodeURIComponent(parameters.query); - } - return url; + return uri.formatClientLink( + 'post', id, 'edit', + parameters ? {query: parameters.query} : {}); } function makePostLink(id, includeHash) { @@ -175,7 +172,7 @@ function makePostLink(id, includeHash) { return api.hasPrivilege('posts:view') ? makeElement( 'a', - {'href': '/post/' + encodeURIComponent(id)}, + {href: uri.formatClientLink('post', id)}, misc.escapeHtml(text)) : misc.escapeHtml(text); } @@ -191,13 +188,13 @@ function makeTagLink(name, includeHash) { makeElement( 'a', { - 'href': '/tag/' + encodeURIComponent(name), - 'class': misc.makeCssName(category, 'tag'), + href: uri.formatClientLink('tag', name), + class: misc.makeCssName(category, 'tag'), }, misc.escapeHtml(text)) : makeElement( 'span', - {'class': misc.makeCssName(category, 'tag')}, + {class: misc.makeCssName(category, 'tag')}, misc.escapeHtml(text)); } @@ -206,7 +203,7 @@ function makeUserLink(user) { text += user && user.name ? misc.escapeHtml(user.name) : 'Anonymous'; const link = user && api.hasPrivilege('users:view') ? makeElement( - 'a', {'href': '/user/' + encodeURIComponent(user.name)}, text) : + 'a', {href: uri.formatClientLink('user', user.name)}, text) : text; return makeElement('span', {class: 'user'}, link); } @@ -385,6 +382,7 @@ function getTemplate(templatePath) { makeElement: makeElement, makeCssName: misc.makeCssName, makeNumericInput: makeNumericInput, + formatClientLink: uri.formatClientLink }); return htmlToDom(templateFactory(ctx)); }; diff --git a/client/js/views/home_view.js b/client/js/views/home_view.js index af3d2ddc..00ffc7d1 100644 --- a/client/js/views/home_view.js +++ b/client/js/views/home_view.js @@ -1,7 +1,7 @@ 'use strict'; const router = require('../router.js'); -const misc = require('../util/misc.js'); +const uri = require('../util/uri.js'); const views = require('../util/views.js'); const PostContentControl = require('../controls/post_content_control.js'); const PostNotesOverlayControl @@ -88,7 +88,7 @@ class HomeView { _evtFormSubmit(e) { e.preventDefault(); this._searchInputNode.blur(); - router.show('/posts/' + misc.formatUrlParameters({ + router.show(uri.formatClientLink('posts', { query: this._searchInputNode.value})); } } diff --git a/client/js/views/manual_page_view.js b/client/js/views/manual_page_view.js index 94fc70b9..e3767c0d 100644 --- a/client/js/views/manual_page_view.js +++ b/client/js/views/manual_page_view.js @@ -2,7 +2,6 @@ const router = require('../router.js'); const keyboard = require('../util/keyboard.js'); -const misc = require('../util/misc.js'); const views = require('../util/views.js'); const holderTemplate = views.getTemplate('manual-pager'); diff --git a/client/js/views/post_main_view.js b/client/js/views/post_main_view.js index 7c115dbd..452d63f9 100644 --- a/client/js/views/post_main_view.js +++ b/client/js/views/post_main_view.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const views = require('../util/views.js'); +const uri = require('../util/uri.js'); const keyboard = require('../util/keyboard.js'); const PostContentControl = require('../controls/post_content_control.js'); const PostNotesOverlayControl = @@ -61,19 +62,19 @@ class PostMainView { keyboard.bind('e', () => { if (ctx.editMode) { - router.show('/post/' + ctx.post.id); + router.show(uri.formatClientLink('post', ctx.post.id)); } else { - router.show('/post/' + ctx.post.id + '/edit'); + router.show(uri.formatClientLink('post', ctx.post.id, 'edit')); } }); keyboard.bind(['a', 'left'], () => { if (ctx.prevPostId) { - router.show('/post/' + ctx.prevPostId); + router.show(uri.formatClientLink('post', ctx.prevPostId)); } }); keyboard.bind(['d', 'right'], () => { if (ctx.nextPostId) { - router.show('/post/' + ctx.nextPostId); + router.show(uri.formatClientLink('post', ctx.nextPostId)); } }); } diff --git a/client/js/views/users_header_view.js b/client/js/views/users_header_view.js index e4f1c149..08b7620c 100644 --- a/client/js/views/users_header_view.js +++ b/client/js/views/users_header_view.js @@ -1,7 +1,6 @@ 'use strict'; const events = require('../events.js'); -const misc = require('../util/misc.js'); const search = require('../util/search.js'); const views = require('../util/views.js'); diff --git a/client/package.json b/client/package.json index add6de7f..3999e32b 100644 --- a/client/package.json +++ b/client/package.json @@ -22,7 +22,6 @@ "merge": "^1.2.0", "mousetrap": "^1.5.3", "nprogress": "^0.2.0", - "path-to-regexp": "^1.5.1", "stylus": "^0.54.2", "superagent": "^1.8.3", "uglify-js": "git://github.com/mishoo/UglifyJS2.git#harmony", From 6b42d787a7a591a882792d370622c81202531e26 Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 21 Jan 2017 00:12:28 +0100 Subject: [PATCH 007/159] server: fix problems with escaping --- INSTALL.md | 5 +++-- client/js/util/uri.js | 4 ++-- server/szurubooru/rest/app.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index c4ccdcc3..1b3ebbae 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -159,13 +159,14 @@ the one in the `config.yaml`, so that client knows how to access the backend! server { listen 80; server_name great.dude; - merge_slashes off; # to support post tags such as /// location ~ ^/api$ { return 302 /api/; } location ~ ^/api/(.*)$ { - proxy_pass http://127.0.0.1:6666/$1$is_args$args; + if ($request_uri ~* "/api/(.*)") { # preserve PATH_INFO as-is + proxy_pass http://127.0.0.1:6666/$1; + } } location / { root /home/rr-/src/maintained/szurubooru/client/public; diff --git a/client/js/util/uri.js b/client/js/util/uri.js index 52e90a06..a0482280 100644 --- a/client/js/util/uri.js +++ b/client/js/util/uri.js @@ -24,11 +24,11 @@ function formatApiLink(...values) { } function escapeParam(text) { - return encodeURIComponent(text).replace(/%/g, '$'); + return encodeURIComponent(text); } function unescapeParam(text) { - return decodeURIComponent(text.replace(/\$/g, '%')); + return decodeURIComponent(text); } function formatClientLink(...values) { diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index 0823e6e3..b4d23d42 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -31,7 +31,7 @@ def _get_headers(env): def _create_context(env): method = env['REQUEST_METHOD'] - path = urllib.parse.unquote('/' + env['PATH_INFO'].lstrip('/')) + path = '/' + env['PATH_INFO'].lstrip('/') headers = _get_headers(env) files = {} From e5f250260d6a4adc611a4ca758ceda2b91f2f314 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 20 Jan 2017 23:51:51 +0100 Subject: [PATCH 008/159] server: make gunicorn friendly --- server/host-waitress | 3 +-- server/szurubooru/facade.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/server/host-waitress b/server/host-waitress index 75980818..ce82430f 100755 --- a/server/host-waitress +++ b/server/host-waitress @@ -10,7 +10,7 @@ import argparse import os.path import sys import waitress -from szurubooru.facade import create_app +from szurubooru.facade import app def main(): parser = argparse.ArgumentParser('Starts szurubooru using waitress.') @@ -19,7 +19,6 @@ def main(): parser.add_argument('--host', help='IP to listen on', default='0.0.0.0') args = parser.parse_args() - app = create_app() waitress.serve(app, host=args.host, port=args.port) if __name__ == '__main__': diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index e29aae13..7bec33aa 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -113,3 +113,6 @@ def create_app(): rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) return rest.application + + +app = create_app() From 2ab559c7e56b1e83cc324f3217b4af4d322257cc Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 21 Jan 2017 00:08:29 +0100 Subject: [PATCH 009/159] docs/install: describe how to run with gunicorn --- INSTALL.md | 46 +++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index 1b3ebbae..5e606c6a 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -134,26 +134,25 @@ It is recommended to rebuild the frontend after each change to configuration. tries not to impose any networking configurations on the user, so it is the user's responsibility to wire these to their web server. -Below are described the methods to integrate the API into a web server: +The static files are located in the `client/public/data` directory and are +meant to be exposed directly to the end users. -1. Run API locally with `waitress`, and bind it with a reverse proxy. In this - approach, the user needs to (from within `virtualenv`) install `waitress` - with `pip install waitress` and then start `szurubooru` with `./host-waitress` - from within the `server/` directory (see `--help` for details). Then the - user needs to add a virtual host that delegates the API requests to the - local API server, and the browser requests to the `client/public/` - directory. -2. Alternatively, Apache users can use `mod_wsgi`. -3. Alternatively, users can use other WSGI frontends such as `gunicorn` or - `uwsgi`, but they'll need to write wrapper scripts themselves. +The API should be exposed using WSGI server such as `waitress`, `gunicorn` or +similar. Other configurations might be possible but I didn't pursue them. Note that the API URL in the virtual host configuration needs to be the same as the one in the `config.yaml`, so that client knows how to access the backend! #### Example -**nginx configuration** - wiring API `http://great.dude/api/` to -`localhost:6666` to avoid fiddling with CORS: +In this example: + +- The booru is accessed from `http://great.dude/` +- The API is accessed from `http://great.dude/api` +- The API server listens locally on port 6666, and is proxied by nginx +- The static files are served from `/srv/www/booru/client/public/data` + +**nginx configuration**: ```nginx server { @@ -169,7 +168,7 @@ server { } } location / { - root /home/rr-/src/maintained/szurubooru/client/public; + root /srv/www/booru/client/public; try_files $uri /index.htm; } } @@ -181,8 +180,21 @@ server { api_url: 'http://big.dude/api/' base_url: 'http://big.dude/' data_url: 'http://big.dude/data/' -data_dir: '/home/rr-/src/maintained/szurubooru/client/public/data' +data_dir: '/srv/www/booru/client/public/data' ``` -Then the backend is started with `host-waitress` from within `virtualenv` and -`./server/` directory. +To run the server using `waitress`: + +```console +user@host:szuru/server$ source python_modules/bin/activate +(python_modules) user@host:szuru/server$ pip install waitress +(python_modules) user@host:szuru/server$ waitress-serve --port 6666 szurubooru.facade:app +``` + +or `gunicorn`: + +```console +user@host:szuru/server$ source python_modules/bin/activate +(python_modules) user@host:szuru/server$ pip install gunicorn +(python_modules) user@host:szuru/server$ gunicorn szurubooru.facade:app -b 127.0.0.1:6666 +``` From 783171729f2cadd8062292fb6f9ecc8d5fc497fa Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 21 Jan 2017 00:08:55 +0100 Subject: [PATCH 010/159] server: remove unneeded waitress wrapper --- server/host-waitress | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100755 server/host-waitress diff --git a/server/host-waitress b/server/host-waitress deleted file mode 100755 index ce82430f..00000000 --- a/server/host-waitress +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python3 - -''' -Script facade for direct execution with waitress WSGI server. -Note that szurubooru can be also run using ``python -m szurubooru``, when in -the repository's root directory. -''' - -import argparse -import os.path -import sys -import waitress -from szurubooru.facade import app - -def main(): - parser = argparse.ArgumentParser('Starts szurubooru using waitress.') - parser.add_argument( - '-p', '--port', type=int, help='port to listen on', default=6666) - parser.add_argument('--host', help='IP to listen on', default='0.0.0.0') - args = parser.parse_args() - - waitress.serve(app, host=args.host, port=args.port) - -if __name__ == '__main__': - main() From 9b27e113b37ccefca54c682ac7cfdafc5beead46 Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 21 Jan 2017 00:20:32 +0100 Subject: [PATCH 011/159] server/search: escape backslashes in search --- server/szurubooru/search/configs/util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/szurubooru/search/configs/util.py b/server/szurubooru/search/configs/util.py index 451179b9..7cb36e9f 100644 --- a/server/szurubooru/search/configs/util.py +++ b/server/szurubooru/search/configs/util.py @@ -6,8 +6,9 @@ from szurubooru.search import criteria def wildcard_transformer(value): return (value - .replace('%', r'\%') - .replace('_', r'\_') + .replace('\\', '\\\\') + .replace('%', '\\%') + .replace('_', '\\_') .replace('*', '%')) From 0cfc9bcafd716c736fab307cf02e3a34fa393cd5 Mon Sep 17 00:00:00 2001 From: rr- Date: Wed, 25 Jan 2017 17:10:19 +0100 Subject: [PATCH 012/159] server/posts: fix handling corrupt files In case of a ProcessingError, the image dimensions are set to None. But after that, they are compared with 0, which resulted in a TypeError. --- server/szurubooru/func/posts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 755090d7..49f384c6 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -334,7 +334,8 @@ def update_post_content(post, content): except errors.ProcessingError: post.canvas_width = None post.canvas_height = None - if post.canvas_width <= 0 or post.canvas_height <= 0: + if (post.canvas_width is not None and post.canvas_width <= 0) \ + or (post.canvas_height is not None and post.canvas_height <= 0): post.canvas_width = None post.canvas_height = None setattr(post, '__content', content) From f42fbbdc5657686b1061a0d27666b9a4f395c88c Mon Sep 17 00:00:00 2001 From: rr- Date: Wed, 25 Jan 2017 17:12:06 +0100 Subject: [PATCH 013/159] server/images: support webm with multiple streams --- server/szurubooru/func/images.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index c7dfb34e..a04d20dd 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -36,6 +36,7 @@ class Image: '-i', '{path}', '-f', 'image2', '-vf', _SCALE_FIT_FMT.format(width=width, height=height), + '-map', '0:v:0', '-vframes', '1', '-vcodec', 'png', '-', @@ -56,6 +57,7 @@ class Image: return self._execute([ '-i', '{path}', '-f', 'image2', + '-map', '0:v:0', '-vframes', '1', '-vcodec', 'png', '-', @@ -68,6 +70,7 @@ class Image: '-i', '{path}', '-f', 'image2', '-filter_complex', 'overlay', + '-map', '0:v:0', '-vframes', '1', '-vcodec', 'mjpeg', '-', @@ -106,5 +109,6 @@ class Image: ], program='ffprobe').decode('utf-8')) assert 'format' in self.info assert 'streams' in self.info - if len(self.info['streams']) != 1: - raise errors.ProcessingError('Multiple video streams detected.') + if len(self.info['streams']) < 1: + logger.warning('The video contains no video streams.') + raise errors.ProcessingError('The video contains no video streams.') From aa1faa3ccbfbe8b5ee691a113aa52bce7e28ebba Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 18:21:21 +0100 Subject: [PATCH 014/159] server/image-hash: improve exception handling --- server/szurubooru/api/post_api.py | 10 +++- server/szurubooru/errors.py | 4 ++ server/szurubooru/facade.py | 8 +++ server/szurubooru/func/image_hash.py | 77 ++++++++++++++++++---------- server/szurubooru/rest/errors.py | 5 ++ 5 files changed, 74 insertions(+), 30 deletions(-) diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 5b38b18e..cbf8f27e 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,5 +1,5 @@ import datetime -from szurubooru import search, db +from szurubooru import search, db, errors from szurubooru.rest import routes from szurubooru.func import ( auth, tags, posts, snapshots, favorites, scores, util, versions) @@ -211,6 +211,12 @@ def get_posts_around(ctx, params): def get_posts_by_image(ctx, _params=None): auth.verify_privilege(ctx.user, 'posts:reverse_search') content = ctx.get_file('content', required=True) + + try: + lookalikes = posts.search_by_image(content) + except (errors.ThirdPartyError, errors.ProcessingError): + lookalikes = [] + return { 'exactPost': _serialize_post(ctx, posts.search_by_image_exact(content)), @@ -220,6 +226,6 @@ def get_posts_by_image(ctx, _params=None): 'distance': lookalike.distance, 'post': _serialize_post(ctx, lookalike.post), } - for lookalike in posts.search_by_image(content) + for lookalike in lookalikes ], } diff --git a/server/szurubooru/errors.py b/server/szurubooru/errors.py index f7edf85d..4fbb67b6 100644 --- a/server/szurubooru/errors.py +++ b/server/szurubooru/errors.py @@ -46,3 +46,7 @@ class MissingRequiredParameterError(ValidationError): class InvalidParameterError(ValidationError): pass + + +class ThirdPartyError(BaseError): + pass diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 7bec33aa..30e5836e 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -44,6 +44,13 @@ def _on_processing_error(ex): raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error') +def _on_third_party_error(ex): + raise _map_error( + ex, + rest.errors.HttpInternalServerError, + 'Server configuration error') + + def _on_stale_data_error(_ex): raise rest.errors.HttpConflict( name='IntegrityError', @@ -110,6 +117,7 @@ def create_app(): rest.errors.handle(errors.IntegrityError, _on_integrity_error) rest.errors.handle(errors.NotFoundError, _on_not_found_error) rest.errors.handle(errors.ProcessingError, _on_processing_error) + rest.errors.handle(errors.ThirdPartyError, _on_third_party_error) rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) return rest.application diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index ebd96cd5..8f8c4d72 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -1,10 +1,13 @@ +import logging import elasticsearch import elasticsearch_dsl +import xml.etree from image_match.elasticsearch_driver import SignatureES -from szurubooru import config +from szurubooru import config, errors # pylint: disable=invalid-name +logger = logging.getLogger(__name__) es = elasticsearch.Elasticsearch([{ 'host': config.config['elasticsearch']['host'], 'port': config.config['elasticsearch']['port'], @@ -12,6 +15,28 @@ es = elasticsearch.Elasticsearch([{ session = SignatureES(es, index='szurubooru') +def _safe_blanket(default_param_factory): + def wrapper_outer(target_function): + def wrapper_inner(*args, **kwargs): + try: + return target_function(*args, **kwargs) + except elasticsearch.exceptions.NotFoundError: + # index not yet created, will be created dynamically by + # add_image() + return default_param_factory() + except elasticsearch.exceptions.ElasticsearchException as ex: + logger.warning('Problem with elastic search: %s' % ex) + raise errors.ThirdPartyError( + 'Error connecting to elastic search.') + except xml.etree.ElementTree.ParseError as ex: + # image-match issue #60 + raise errors.ProcessingError('Not an image.') + except Exception as ex: + raise errors.ThirdPartyError('Unknown error (%s).' % ex) + return wrapper_inner + return wrapper_outer + + class Lookalike: def __init__(self, score, distance, path): self.score = score @@ -19,39 +44,37 @@ class Lookalike: self.path = path +@_safe_blanket(lambda: None) def add_image(path, image_content): if not path or not image_content: return session.add_image(path=path, img=image_content, bytestream=True) +@_safe_blanket(lambda: None) def delete_image(path): if not path: return - try: - es.delete_by_query( - index=session.index, - doc_type=session.doc_type, - body={'query': {'term': {'path': path}}}) - except elasticsearch.exceptions.NotFoundError: - pass + es.delete_by_query( + index=session.index, + doc_type=session.doc_type, + body={'query': {'term': {'path': path}}}) +@_safe_blanket(lambda: []) def search_by_image(image_content): - try: - for result in session.search_image( - path=image_content, # sic - bytestream=True): - yield Lookalike( - score=result['score'], - distance=result['dist'], - path=result['path']) - except elasticsearch.exceptions.ElasticsearchException: - raise - except Exception: - yield from [] + ret = [] + for result in session.search_image( + path=image_content, # sic + bytestream=True): + ret.append(Lookalike( + score=result['score'], + distance=result['dist'], + path=result['path'])) + return ret +@_safe_blanket(lambda: None) def purge(): es.delete_by_query( index=session.index, @@ -59,12 +82,10 @@ def purge(): body={'query': {'match_all': {}}}) +@_safe_blanket(lambda: set()) def get_all_paths(): - try: - search = ( - elasticsearch_dsl.Search( - using=es, index=session.index, doc_type=session.doc_type) - .source(['path'])) - return set(h.path for h in search.scan()) - except elasticsearch.exceptions.NotFoundError: - return set() + search = ( + elasticsearch_dsl.Search( + using=es, index=session.index, doc_type=session.doc_type) + .source(['path'])) + return set(h.path for h in search.scan()) diff --git a/server/szurubooru/rest/errors.py b/server/szurubooru/rest/errors.py index 3c40b4c5..b0f5b882 100644 --- a/server/szurubooru/rest/errors.py +++ b/server/szurubooru/rest/errors.py @@ -47,5 +47,10 @@ class HttpMethodNotAllowed(BaseHttpError): reason = 'Method Not Allowed' +class HttpInternalServerError(BaseHttpError): + code = 500 + reason = 'Internal Server Error' + + def handle(exception_type, handler): error_handlers[exception_type] = handler From ec9c70ba680795314c2af87169c4eabff934cc13 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 19:38:53 +0100 Subject: [PATCH 015/159] server/facade: disable elasticsearch logs Errors are covered by new safety mechanisms in image hash. --- server/szurubooru/facade.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 30e5836e..5aae121f 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -101,6 +101,7 @@ def create_app(): ''' Create a WSGI compatible App object. ''' validate_config() coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') + logging.getLogger('elasticsearch').disabled = True if config.config['debug']: logging.getLogger('szurubooru').setLevel(logging.INFO) if config.config['show_sql']: From 8be0e731a7224ec928cc20c70d80eddc04c11eb2 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 19:39:30 +0100 Subject: [PATCH 016/159] server/facade: run without elasticsearch ...but don't let user upload any images until they fix their configuration --- server/szurubooru/facade.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 5aae121f..26a64473 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -110,7 +110,11 @@ def create_app(): purge_thread = threading.Thread(target=purge_old_uploads) purge_thread.daemon = True purge_thread.start() - posts.populate_reverse_search() + + try: + posts.populate_reverse_search() + except errors.ThirdPartyError: + pass rest.errors.handle(errors.AuthError, _on_auth_error) rest.errors.handle(errors.ValidationError, _on_validation_error) From 07d0b43d4c75738a68651dccf748320a3967f06c Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 19:39:57 +0100 Subject: [PATCH 017/159] server/posts: reduce warnings from sqlalchemy ...regarding empty IN() statements --- server/szurubooru/func/posts.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 49f384c6..de58db02 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -378,10 +378,13 @@ def update_post_relations(post, new_post_ids): 'A relation must be numeric post ID.') old_posts = post.relations old_post_ids = [int(p.post_id) for p in old_posts] - new_posts = db.session \ - .query(db.Post) \ - .filter(db.Post.post_id.in_(new_post_ids)) \ - .all() + if new_post_ids: + new_posts = db.session \ + .query(db.Post) \ + .filter(db.Post.post_id.in_(new_post_ids)) \ + .all() + else: + new_posts = [] if len(new_posts) != len(new_post_ids): raise InvalidPostRelationError('One of relations does not exist.') if post.post_id in new_post_ids: From af6c35ed6ba25dd4a1c0c51481652bc80bfac9d1 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 19:40:21 +0100 Subject: [PATCH 018/159] server/rest: rollback session on query exception Kills complaints from sqlalchemy when an error happens during insertion/update hook. --- server/szurubooru/rest/app.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index b4d23d42..391d297f 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -93,6 +93,9 @@ def application(env, start_response): hook(ctx) try: response = handler(ctx, match.groupdict()) + except: + ctx.session.rollback() + raise finally: for hook in middleware.post_hooks: hook(ctx) From cce543e0b692e791adfc56f02ad0e90bb2a62172 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 19:45:18 +0100 Subject: [PATCH 019/159] server/posts: commit reverse search population --- server/szurubooru/facade.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 26a64473..c85ac1ae 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -1,12 +1,10 @@ -''' Exports create_app. ''' - import os import time import logging import threading import coloredlogs import sqlalchemy.orm.exc -from szurubooru import config, errors, rest +from szurubooru import config, db, errors, rest from szurubooru.func import posts, file_uploads # pylint: disable=unused-import from szurubooru import api, middleware @@ -113,6 +111,7 @@ def create_app(): try: posts.populate_reverse_search() + db.session.commit() except errors.ThirdPartyError: pass From e92bd2fd8007cde5b5e6f39d9ece09cb4119c468 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 20:03:24 +0100 Subject: [PATCH 020/159] server/tags: fix getting default category name No categories? Should have thrown an error rather than returning None. --- server/szurubooru/func/tag_categories.py | 4 ++-- server/szurubooru/tests/func/test_tag_categories.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index bca15df2..a9169dec 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -139,8 +139,8 @@ def get_default_category(lock=False): def get_default_category_name(): if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY): return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY) - default_category = try_get_default_category() - default_category_name = default_category.name if default_category else None + default_category = get_default_category() + default_category_name = default_category.name cache.put(DEFAULT_CATEGORY_NAME_CACHE_KEY, default_category_name) return default_category_name diff --git a/server/szurubooru/tests/func/test_tag_categories.py b/server/szurubooru/tests/func/test_tag_categories.py index 70f0aa0e..cf74c2a5 100644 --- a/server/szurubooru/tests/func/test_tag_categories.py +++ b/server/szurubooru/tests/func/test_tag_categories.py @@ -193,7 +193,8 @@ def test_get_default_category_name(tag_category_factory): assert tag_categories.get_default_category_name() == category1.name db.session.query(db.TagCategory).delete() cache.purge() - assert tag_categories.get_default_category_name() is None + with pytest.raises(tag_categories.TagCategoryNotFoundError): + tag_categories.get_default_category_name() def test_get_default_category_name_caching(tag_category_factory): From f2fd769767b5cadd139cc7f7fac350d1d0865b0b Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 20:06:20 +0100 Subject: [PATCH 021/159] server/migrations: fix imports for alembic `alembic revision -m 'blah blah'` rightfully complained about imports (in case of `upgrade`, that module was being populated by some other module.) --- .../migrations/versions/9837fc981ec7_add_order_to_tag_names.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/szurubooru/migrations/versions/9837fc981ec7_add_order_to_tag_names.py b/server/szurubooru/migrations/versions/9837fc981ec7_add_order_to_tag_names.py index 10969fe6..39f9edff 100644 --- a/server/szurubooru/migrations/versions/9837fc981ec7_add_order_to_tag_names.py +++ b/server/szurubooru/migrations/versions/9837fc981ec7_add_order_to_tag_names.py @@ -7,6 +7,7 @@ Created at: 2016-08-28 19:03:59.831527 import sqlalchemy as sa from alembic import op +import sqlalchemy.ext.declarative revision = '9837fc981ec7' From accdb51c0b969901d4cfe5f1895ca74e61c582f2 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 20:26:22 +0100 Subject: [PATCH 022/159] server/migrations: add default tag category --- .../5f00af3004a4_add_default_tag_category.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 server/szurubooru/migrations/versions/5f00af3004a4_add_default_tag_category.py diff --git a/server/szurubooru/migrations/versions/5f00af3004a4_add_default_tag_category.py b/server/szurubooru/migrations/versions/5f00af3004a4_add_default_tag_category.py new file mode 100644 index 00000000..d77dc37b --- /dev/null +++ b/server/szurubooru/migrations/versions/5f00af3004a4_add_default_tag_category.py @@ -0,0 +1,62 @@ +''' +Add default tag category + +Revision ID: 5f00af3004a4 +Created at: 2017-02-02 20:06:13.336380 +''' + +import sqlalchemy as sa +from alembic import op +import sqlalchemy.ext.declarative +import sqlalchemy.orm.session + + +revision = '5f00af3004a4' +down_revision = '9837fc981ec7' +branch_labels = None +depends_on = None + + +Base = sa.ext.declarative.declarative_base() + + +class TagCategory(Base): + __tablename__ = 'tag_category' + __table_args__ = {'extend_existing': True} + + tag_category_id = sa.Column('id', sa.Integer, primary_key=True) + version = sa.Column('version', sa.Integer, nullable=False) + name = sa.Column('name', sa.Unicode(32), nullable=False) + color = sa.Column('color', sa.Unicode(32), nullable=False) + default = sa.Column('default', sa.Boolean, nullable=False) + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } + + +def upgrade(): + session = sa.orm.session.Session(bind=op.get_bind()) + if session.query(TagCategory).count() == 0: + category = TagCategory() + category.name = 'default' + category.color = 'default' + category.version = 1 + category.default = True + session.add(category) + session.commit() + + +def downgrade(): + session = sa.orm.session.Session(bind=op.get_bind()) + default_category = session \ + .query(TagCategory) \ + .filter(TagCategory.name == 'default') \ + .filter(TagCategory.color == 'default') \ + .filter(TagCategory.version == 1) \ + .filter(TagCategory.default == True) \ + .one_or_none() + if default_category: + session.delete(default_category) + session.commit() From f828c375e6325583c2d5c0eea4230cd426dcdc62 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 2 Feb 2017 21:52:52 +0100 Subject: [PATCH 023/159] server/posts: fix reverse search late evaluation Uploading webms caused 'Not an image.' error to be shown, cause generators are evaluated lazily, so the `catch` never worked. --- server/szurubooru/func/posts.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index de58db02..51d11f31 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -553,11 +553,13 @@ def search_by_image_exact(image_content): def search_by_image(image_content): + ret = [] for result in image_hash.search_by_image(image_content): - yield PostLookalike( + ret.append(PostLookalike( score=result.score, distance=result.distance, - post=get_post_by_id(result.path)) + post=get_post_by_id(result.path))) + return ret def populate_reverse_search(): From b21ffac8205269b909275f6765560719bc399db2 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 3 Feb 2017 19:22:33 +0100 Subject: [PATCH 024/159] server/scripts: make pytest happier --- server/test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/test b/server/test index df8ae5f5..6d7bb6de 100755 --- a/server/test +++ b/server/test @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import pytest import sys -pytest.main(' '.join([ +pytest.main([ '--cov-report=term-missing', '--cov=szurubooru', - ] + (sys.argv[1:] or ['szurubooru']))) + ] + (sys.argv[1:] or ['szurubooru'])) From 894cd295119b33fe786b51558b8b53033813ae27 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 3 Feb 2017 19:53:10 +0100 Subject: [PATCH 025/159] server/tests: test image hash --- config.yaml.dist | 1 + server/szurubooru/func/image_hash.py | 10 +++++++- .../tests/api/test_post_creating.py | 3 ++- .../szurubooru/tests/assets/jpeg-similar.jpg | Bin 0 -> 10810 bytes server/szurubooru/tests/conftest.py | 4 +-- .../szurubooru/tests/func/test_image_hash.py | 24 ++++++++++++++++++ server/szurubooru/tests/func/test_posts.py | 11 ++------ 7 files changed, 40 insertions(+), 13 deletions(-) create mode 100644 server/szurubooru/tests/assets/jpeg-similar.jpg create mode 100644 server/szurubooru/tests/func/test_image_hash.py diff --git a/config.yaml.dist b/config.yaml.dist index a38393a0..1f4a2307 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -36,6 +36,7 @@ smtp: elasticsearch: host: localhost port: 9200 + index: szurubooru limits: users_per_page: 20 diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index 8f8c4d72..ae368488 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -12,7 +12,10 @@ es = elasticsearch.Elasticsearch([{ 'host': config.config['elasticsearch']['host'], 'port': config.config['elasticsearch']['port'], }]) -session = SignatureES(es, index='szurubooru') + + +def _get_session(): + return SignatureES(es, index=config.config['elasticsearch']['index']) def _safe_blanket(default_param_factory): @@ -48,6 +51,7 @@ class Lookalike: def add_image(path, image_content): if not path or not image_content: return + session = _get_session() session.add_image(path=path, img=image_content, bytestream=True) @@ -55,6 +59,7 @@ def add_image(path, image_content): def delete_image(path): if not path: return + session = _get_session() es.delete_by_query( index=session.index, doc_type=session.doc_type, @@ -64,6 +69,7 @@ def delete_image(path): @_safe_blanket(lambda: []) def search_by_image(image_content): ret = [] + session = _get_session() for result in session.search_image( path=image_content, # sic bytestream=True): @@ -76,6 +82,7 @@ def search_by_image(image_content): @_safe_blanket(lambda: None) def purge(): + session = _get_session() es.delete_by_query( index=session.index, doc_type=session.doc_type, @@ -84,6 +91,7 @@ def purge(): @_safe_blanket(lambda: set()) def get_all_paths(): + session = _get_session() search = ( elasticsearch_dsl.Search( using=es, index=session.index, doc_type=session.doc_type) diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 341990c8..9737a73b 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -258,7 +258,8 @@ def test_omitting_optional_field( def test_errors_not_spending_ids( - config_injector, tmpdir, context_factory, read_asset, user_factory): + config_injector, tmpdir, context_factory, read_asset, user_factory, + skip_post_hashing): config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'data_url': 'example.com', diff --git a/server/szurubooru/tests/assets/jpeg-similar.jpg b/server/szurubooru/tests/assets/jpeg-similar.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af6120924c7c1d5a28cc9c2a59388d2b758fab4c GIT binary patch literal 10810 zcmbtZRag{Uv>m!T1ZfyTN*d|zl9+*^6oyXeP625cx{>Z0IusP7L%MV5lt%LUKiu!T z?{}Zaz1G_MoQL!6Z?Bhymkj`+s-lV_00{{Rp#1LuUY-HW@6BwT-0Ah4ogL`8bY)by z=vCh>xWwc!w`OHGdnP3GCKPLa2zoG#0Z~miFBLV&s{y);_7$~S{$e2ige_64A zO&}wo{1g5c=qUf{|IGDn_)1y6p+b=z;u3W{sHovLmMXLbX^Ny zmH@c_@W@0cL;xwkcZ*K}-E>>QdC10fl1Bf zM+j_^4^uWfmR#cgeyZ064BpL3=LA={CHmM+4*#(wiM0Qk~P^ltun#R zP(UTP>?5i3+Dxp9rq`z*OM8wxKONpyNF#4n2y-^3)`DMaVQNc*o@;1z&60GM4oZ~- z%#*uaFD+R{)Nh;L3h=3 zE0JS0?PUt&F&w!+$Ts{i&bzhK-!bfE!|5UzHDHk4`7{c`gqFIRwo+%Sr`ZeBr(Q#K z$C|cW*^>-)5BL_7QAf`Ee|E`4tKp6~3jzx$e@7>Y=~RAkxokdVxKs&bpFPp`B`1mv zXMRQ64%@2V7aLpq8C4Rk0brZ?DDIwSZd)hJN`|ouK0W|xz-9IMTDP)iotA7trENIRTj5V> z(KlZx^_!i{tD`P#EPA&QGsbxV+qoZ!O_|~aa25W6b~&=AXFmx7%^m$y1=L02wAm(C zNW{Q&d8S!*JBz(ic>xd%oNeJ39VPp{weafRxqksP-JCzvrbZ%`UjTGvt+$A}+Nf7| zH@CMh0PgjKyBEM)nyPxqvd0%dnNAhapg7XQ!?thGchO&2q;4QFzZU@Q)AY&BV?gJ_ zGu=qv0qaaoz?0H?1;Q=o;h(trJ(L={zW&%1j`spc?oLksmAGym_>iC<9{Duc_zzHL z+RJ$Iuh!e3y<4rI-Uo58_4DUj32iR~Mvk8~HgBF(YXMWfsV|Z5^CCA!5qEK2rk1TA z`#Jhtmgthp9&jH1oDuCh^js+N{?p-6jO0?!Vplyjrmh;(q<;;hX;!5U94|{|YLjTO z;l#&4DKZFnGX%I$!YEgU-BOGL^=wJ=cPwd*@v+3BWAVMIeXswX>BO-+B)Q^K7%<4f$`xLQ6xcwoDbP_Nl@f{Y*Kx zj?;on0?}*p3bmd8ekji?^7&n`Jq20T*>oc-Sc*gkM;xSxVoz}z8R^U7xC#S_`BN|| z_wC3W?uv9~CApt|e_NW?5})1o_EY^D+kgmig*20SUXdEUX%t@&dL=$+g8W{v|>$SW(+hceQxlPQqc(cdT^Wc zreyl3;Dharpw#4CR5t0%SE?i*$k0xG^-t~bI&6wxHOSF5i6c3q^GUi^uB5ByY2jsZ zz!6K%3^|S(tGHnwHGrzMUn(H-D1IeSm!n)qK~xei6MIaIlX$S7qI!Wv;B@du7M=); zTU4<_x2vt`N%EfS=?{6;`)WGU>o0@c1s$d?LmPFM?ZnNx6T8JbG^OOZcG#jxL?#C~ zQRztujrB!kSf^yyDYGt|b=PV7kCQty9$YK|irecL_U}g0zxp} zO~|)0xb<61+u?3$8jVhT$Bg)4(tl5aQ#SlJj8u zx{BnZqcoh54Vn9ZpH-fNVKR-J-g`ZQ*bl z`KD@aWESM5v;~m9fAcSz`iL1L+l=NoPCGL+wI3n*=4;N|qFMVeD z;F#mw=fgbUVO9&l!u333pj7^)<_qBLQCvCh*lnP@_2A~YE6AzL7&(B)?U4WK<|*iM zUd!?^^XVVQIKS=%FcNfDonj{__B1Q;0zmu<)S7vI(zjZ^e+bME^qsx&@s*g<(ukRk zjo7_jeCTSvbMoz*D!=luoHC04OQdbESQXDA=*rF^zGk7PcZO(Nh0Uhk5UQR@Ms*I3 z21veFdCvV54^~J7KunC0`S_|aQT^fpOe-F`l6jI88;-vDhq~qXpN5TQT_y-TNNlq< zp7EP8?e6&w4(uyDGM+FPQ3xuoMZ11k6-KLaI*+!r*qg6QmQrq4S{jaQO^~gBR;($q zodcq)mUwX3RVN{w9V@G-`7#N-zCV`$-xp(u*xs>f7CRQ-5)WbgqJWupI4uIlpI>?cu~X;uIqwpcfJS9m8o56;H1Ytf#ith_=98*U8w z58$_0n-maZjzED%5#Z1{ZZLRbTU;1rk(e`WQgzofs0fEbRG=X-7m3Pa3iX&8I>`oa; zlps#!!J2Vwrc3pSr^*W?l=xqx4BzFuA*G~-amBt{q7bq$rJo`D7`~0(gshy^5Vf~f z?fI7;Tivs&WiR3rG^?GnGgHYy->cO5bN|vBwKMCP^QHcqLD7!KRwAf(HAZ-Baoy{F z(DFrm?pa3I_}07DH(DO(Ok5H*dD+%A3BqVkkr9Sn<`B4d7UXJamYt!7I@(J-o+WpL z_;O4<8i2c!nT@S_5=}A3)%nT8LCf_|Zjr%ZsTBi+CydjhVlL{>at>d`FuXr^5l3p5 zAM!a@L^n>^dLW2^T%G6^|@Mdkz7oFd3)d(9*fmT^Pj3(VWcpTW~l2UL@^tX75rt(lA$Ht_=_G2=CTTfgjB^d?!KYci!rENx`vR~b@>~ayKd^R_# z6R&`*;$KSP*Aj31W#o$br>%r2I;JcGzMYh|Dri)u!G|6sET1BhKmNUZoipcK+BoEA zrxsH7>b)FO=|@e%#+_6;2?KG6AtB7L#Ed958Eex!-w%?xljgK0qUb3GRiu&)XJzOh zIv3tlY_#fYtz3>UmNR9)HWb2bW|tu|tXS3eg#Gf?ZD$A<9%>5%^;5H;!1(-A<;Y_g z>}gW{O1v-o_g8qYcX`VXWZ>u^l^@(fBJg6@))AhwSXgm}4<&o*4{`E>_+PeZ-auN& zt5_adn^l4=5#32xu((Si|CG*iu9tcvv)XS7pDjsCm)O`1zWmhRuwHJZnd>eQv9g1+ zEwg_EPte)D%M9fPwJ$f!tD+(@I^l;5`dGquFxSjXw_6XAY`+P>|p_g+J;0vN7 zlq~U?2)|yScBZSezF4kq{xNxYU>ngkPgYpkMB%6Yj``CCVPzGg0`0ECZR*pd;lr!_!%IX2Il9$ERir{D7m zi8Weq(vwo)1wbMPY|i3E7F0egHi^O+__N&`*U(@ELZd(r;$3l-x!Mfyv398im5T+v z03c3-HPbFiRk{kXz+v?`|IhnwZu_*8=h?{XAgKLdyZxed)uf&fo++arPKQ19-#1+P zy`T0)6RS!VY+IFYI4CpqvSNGK#Ro#nlfNig;jAd?fQ-tbNVPc}CbciJ)DNjhoNIjIY)p^LyYaIdPS zYu^^{eoQ@neX~U1uTJ=Cgv#8^h4|cCmF4)csO@oaJhewm0fWZb_yWk8h3 zhsw(`91+`7hJ7*5CdGy=-4?vSC?xODzT(oH+pk~9^qQGAduFwJuec00_vAMCqi(Wb zDiI!RQ-GaxG*|*ws?>EQ2X0;Jl@WVz=if7T^ zi@$PZXzaLPAD7~{l%hXm#;YuX0B|}8)YQZ>Q6nlLhjH`E{9&p$x9#sJnF!@$gLn~x z+-&^*URI=(7HNjrBJsGM4N4pSAd0DnxvFm>N!!=SLE@D~+n;+m#2l#%`0|u6Xvl&H zXiHf~E>H1`A0%kV^%IY-$vi>RlZN`1HGr|%$EyzcPy^N19h?VzG6E&~X~7tm7QXz< z5Q5jpAMg&AX2pils9X=`r#@XMp~8FC@k^fHm}tZ8^YavRmZ|J%Q@dk)$<&s_r~7O* zDi6R+2^Ut;YGxK9g1OLuhg=mG>Pn+Q0=y$@!x7b^LRO~BQd$o22eKkBj`0uY=zu?H;%)j%EY&e9s4V}zIA{~7e7~i{M z*1cjn%d0q??nQ@vTmXBARW6eN8a=lGRIjb-B0#Ji`b)UVp_~Xc4J$xt)a%l=jpyYhK-^%yXDzSGlk3GY@xiA`Ste_rxLb3L$8s4o6vjRu975eF0AK!yIi~KONkr1xEqe3AO z<7#LtRFr6@ROckRt)xP)r6L-EpyS*A!>6)OF+!wR>>Bu6KSy!(0DjlI2|NCGT+&8e zU@dzipzHi$qk9Jfm{nrT!HcYjZvQs9sru7}2L054s6eDe_$1-n=nxTqnDVkkWJI)r zzT$_JDuA~LFeJw>Ph_vuJ(x4TQjWb5Ff5}F_%o~s`ix7cKP)qdsQIpJ#Hpjiq-~xg zJir7Zfj6AE0Tr}{3@c;q?;1%p+WZ3EJw3sB^CWXz<^hRr2eCx=2{#hmlkE2QIEC;% z(<0GSx}+-~F3ml4PLTq)#Vih_cVD#n+}r$b2<$V-hlvqW21)pkQrjbUjrYe)oW6$( zNRc1FG8&4gcL&>IML0|t?>oOF+bglSdO7Q;y+Of1Qel;l6o0eDaj}2(_|}yDj3)XI zx{rf^6nVO&hy$p8Svnk4ZH|@F@-w`6T>8M(L!)0jX27_Qj^9B5(y1aEwOyI?nI*bq zE`qS+3!O3su%VfJio_bErLUWmEX|M4*&oRg|9Y1~{Z+BCO5~*?uWoc|&@Z zWAVaZS(i^XsdE^n=hjlB+5T zTN0Z>a zx0In4W>XW_otv(*AQT0Q%W%3H6#b*MY>YS}jB(I#Medo)OmSg@gBT%Nx{$Yg%2S9? z72f9HF|2%ja6)4!=_T=Y9795a8--JsyhZdN(9{ht>EDilzNM0IJ3_HVpGmF1!A?{v z^h`~4JH?{RF1`SDHoIjfX;j%9jPU1+A}WosXA09xBy~Wo+uTiJvRA25PS<}Q`dG?1 zCa|xrsKO49B3)?YbUzZmDzL#t_R~_98 z{~KB6PL^D%IhU?cIx{a%AvEXpcB09MSXyACm;Er+!;vp;#-xR{*6_*>0_Ldn)8St2 z6c!I&j7{LFJZDlq9xtcalHj=GT*nj4WaURP4|A>H|DmxsuxG$777bM4P$F8RxxBsF2zsYO5HeI;&0CB7RJW$MhF1a#Dpq*hCA%9?$xkcnSNAa}IXDcM6k=xL$ z-&($Foz?HhTE1gXw(d0c0_fe>y))U{7~%?}klKtF5>r~|^YyiUxL&#cG$Zhr-W>f# zOobX9lSgGQq?7eIXD>d2)3VqwCzYmww|HuVh-mDP33kl=y{F?K)xUC{Pd-tLKY48Z z2KK5uw9J?nR|Ho{JP(Vmo%$+eo~j8qEdF@gt2NroPRyVqdQ#nltFhve;G`$EFtKWd~0c48~!s%P2=YU2x4R^AxQtRTY9l^)btXI zlStwZ9LKxs!Qp}4;k~S$k?0G6Q=PU2P4*yZObIkAK$Uqq-3*0eZWJaO(7ymsqlfsA z_9>qt59&4@Rk)|@!s?qn173u77_rY1S`VRGh- zBRU-`;KJ_k%okiZ_SZVcvgEaMuaJ-9sm&Mlr&%305a3hqM~|1N+$Pf>?Ktr=raqNm z`4D?^H!ieUaeeKJ<7D@1Bt2lfTqN^@kZd}hO0lqpq(rU1irZI@6@^2F_V4UiNA~L~ zucHr#@iGc>;}X2{8O{3|E_?z#@k$$*_@=I25%8 z_KmWKyC8&J6*kGzxYdnwCnY}v^=VcE3t~A^O~^1M>n3Zr8(_4W50G(?ASoYe$~Pkn z`rb29EMUzAHXZ&Y36D*H**lRqD?)fb=H{D$y=`4H*W}nUPRJZni-)S76HVo`@T>yd z>xmpjV^cSjjcVv@d=zNkoYzR$|Lhpg=5*~XA7rQ?^^2Xg2z>Hc{F)MaI7G}V@w=GY z|8%r3;Xo3ny>xx5x!e?Esx+L_$jV(UamZNHOp#N+l$Y7cS?%ZrU_Mcy!(P@alBXN zVuD`^NhoILRa46!@AKN2CIeE%0>2+TJ_gVyXwNIS9X;VoeCz)3NL#ge zZZZM^d(aH{{skSWHEu%r3q+tN)cW3BMJI8T%vx({cDcIh_GDr$JSBc0!-PLekD~Rj zzb*aUQP(Noc<6Es<|DZ1H-N^ie?rfd#WQc~1@3uCRslb&3%lmvv*d8x^ zE#f0S#+`O~1Z{ZMyRPsU^LbB#al(=I?k#dVYDD|%hVeuZA_G0UQcv&Pd5SB4mry$K zgvrkw0e=DpRD=2qt++Hxb}(8!v-*utt$w}Ytp{!p+mE1rPTs-0VyYOEd>*d2_fq(0 z5fEHF`b)2UDLGRZz*$Q26uFYk1Q_DZ!fjp2QAanw)HEUMHP>cjF4ee`)f@2HRIjg^ z8@i2K9Q)${ofQHZW|AVhjXsnYV?>j|0X;q+(K#Dl0G&wkL2~W<*GF!$GbK~a*sYU> z3gOXO*c;hoERR(@AGi4jDSPDql#r#4tP(=1JT)p*O-Xb-S9r#(GhZ19FKss(kxx1MnV z&1FJC+fcra9HyI6aJ={D;3K?p6VPUTbdLrV8{6V`Etau=W{=L12uk#2ny@E8FWZ8E z;^LczN!az|2s#-4AdePFA|v2TD>0|O97?SG*^D>W!iF*_IdO1D>Y&F?-w4qSD~`z9ai_Kkb#3 zl8>`lVHJK&K97vNGV)B{&$VXS9CZuZj%N0&D_dZLk!qSF+Zc1+5~xNZ16RXBjg%-- zzOi`ZBfG55fRj91r_AB%1pW~%b3@f3Elu9Uc`luCg3^q@xd;bKI;Zs}XGYFoAq6nr zCgVd_k&>M(jI${xwR1j}u;UK!~ zIP0LwQ~NIXO*J{@>TBMZtPVHBxn|LEw5A%Ts|eNVWI4vgI7E%OQ#Po`BTO#`7r!vBzKug!3sOPfEBz>#1 z5jG5=B9jOhlY+-b{n1LMHHl0bC>p^+E#SRYhTy?QbCn(W!8qqsv(s3#D?plXp@F$} zNJm9K;b|fZ`BmMJ;FZz z+W*_Co4xytlCqqEe4|=gheu8^-@8H<_RZthJSX~cM0DYAbA|P5+C%m0pP~W`(6Zm` z_Iu+uwDl5J9fe}@|8^MtlanmKU6CinD}GaX^;9%<#5X{7B;51kX`rmlo1n5-wn5d8(;yrNpfZ?1!$Q+@zn0n&RM?K zxvJKIg}6*>4yp-K3sgr#dY@E>+DR==nKSvYBuMVS9lxgqv0ItHymHq5qUE1VHt@~p z8-(8cYdO53N^(QYnSun(f|5z(i;|~4&4D3v&P;?J%H@=S{G0}TT&=*H8tQ(DI%x7i zQJ#3bz45LRXKj=bJMS2E3GkKk!OR&Yc}}bS4@<;HIM8jlUaaDyzPq^wf?e?ys-pyj zZA#(Pam*DW!d8=30LD*US$P_XV2JKKzNkXRH1$3tc5oY_U$GsUCqw_9$mGDIFQ7&* zzq64z^z1HNT?wr^!%WjpHilHr9om@)1FnHi~c}30Ax$u{-j@+juJn6b2rFR*e!_q#;MVt;- zK!>K(4=!~grvrSavBhpmlOnI<7X9;-7?(f#(8DL?E{XOytFFZDz z3SPldTGL3POyqMr<-eQ2mtu%5E5j#0Jwg}JT+ts>&6Ogv6n*VZKvV8lD1#Si>{)#L zj=ss3(=vu)CIY+B{+X}4I%<8RN=+q@X19(klbLLdZ985z{@hazqGACl2pHl~y`Y4q zv5bJXB9Pna9T`}9mT{t@8IMO~4b0-(Fe?tKQDQTG)aL&tWsUC(x=L8C3wec-uyS3sX7R?e0GB)@A_>z? zRW(#FH{AP6F=80^=aM`U%9A{$QWjX5QuU5*&$T^fy)~g2Dpf(ZOygq1UQbUT$)6XL zlE?q_J*p5CpFArSM=aW@ zwp3ukIy`05l-_CPg_fjHcWlNC+*m ziq*Gqk(WOiVTk8gJXkOIdNi8JPB51D;9@?IgP;!$1Izb$^<*%`kIE?cK#Zl2GFMy1aMK6@$4MLZ^tTiMz zR|%c){IS|NB9KE8YRG496>)%gmlK`tY>HbrcLJrVNgIOfx@(^gf_;}N&zuHoc-+5` zy_GN)!Ob(sr^?$gSSzmuJ=NjZotlz*_Cx#^uHIeTTd&{;)kzyxV}YOU|D?&+xax(s z`utXUyHoGrKPb)a@&l!W|7VVMcH5C!?I=W96}->q|KZm>hObB1DpkDdi9reS0|)lB zMYNytNT}LpB@21OY+WrYx7|jx2(cjb89J>Z58`h|bR_2QGHZZI6bPQ3qFBy`az)PN z)Y%^^JsZFi{P-u*^5eE88pW)c2Jx(~=2c4M)nU zub6G!ug|I+RiGa}kO9Hc{UjVE(c|oMb3g5}_{j+TC)@i~IHxpg{p&zK$jqWr%9&G& zICzJ^ve>X#fL=L0LWIUFpZo|+lW7%E0c^4RL2bdmvrzR876FMzKW6KB{}tEa1)x|s zuANLHNcK(++{>K!bKnEA^cNIVIJRDQt`HZMyePTDwz%z_8$iZW+eL zYCz6%T$v>q%CzrM9ij`Mk??Xfe01q?0;bZmud`iZ5;Tl)*v{Y&;Xa1L6F|5`j4f*Q zadKa*YVJtJW6i}TKr??P~Jg301t(1alHSpgl>|{Qdbe(#(-4a$~ literal 0 HcmV?d00001 diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 5dd47576..db34ee02 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -155,7 +155,7 @@ def tag_factory(): return factory -@pytest.yield_fixture(autouse=True) +@pytest.yield_fixture def skip_post_hashing(): with patch('szurubooru.func.image_hash.add_image'), \ patch('szurubooru.func.image_hash.delete_image'): @@ -163,7 +163,7 @@ def skip_post_hashing(): @pytest.fixture -def post_factory(): +def post_factory(skip_post_hashing): # pylint: disable=invalid-name def factory( id=None, diff --git a/server/szurubooru/tests/func/test_image_hash.py b/server/szurubooru/tests/func/test_image_hash.py new file mode 100644 index 00000000..3bbdeba2 --- /dev/null +++ b/server/szurubooru/tests/func/test_image_hash.py @@ -0,0 +1,24 @@ +from time import sleep +from szurubooru.func import image_hash + + +def test_hashing(read_asset, config_injector): + config_injector({'elasticsearch': {'index': 'szurubooru_test'}}) + image_hash.purge() + image_hash.add_image('test', read_asset('jpeg.jpg')) + + sleep(0.1) + + paths = image_hash.get_all_paths() + results_exact = image_hash.search_by_image(read_asset('jpeg.jpg')) + results_similar = image_hash.search_by_image(read_asset('jpeg-similar.jpg')) + + assert len(paths) == 1 + assert len(results_exact) == 1 + assert len(results_similar) == 1 + assert results_exact[0].path == 'test' + assert results_exact[0].score == 63 + assert results_exact[0].distance == 0 + assert results_similar[0].path == 'test' + assert results_similar[0].score == 26 + assert abs(results_similar[0].distance - 0.189390583) < 1e-8 diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index f44fb03b..a10cc6d9 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -290,15 +290,8 @@ def test_update_post_source_with_too_long_string(): ), ]) def test_update_post_content_for_new_post( - tmpdir, - config_injector, - post_factory, - read_asset, - is_existing, - input_file, - expected_mime_type, - expected_type, - output_file_name): + tmpdir, config_injector, post_factory, read_asset, is_existing, + input_file, expected_mime_type, expected_type, output_file_name): with patch('szurubooru.func.util.get_sha1'): util.get_sha1.return_value = 'crc' config_injector({ From fd30675124e2ba0fb07fa15af1bec7432cb9f485 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 3 Feb 2017 21:20:52 +0100 Subject: [PATCH 026/159] server/image-hash: do not depend on image-match While I hold this library in great esteem for its excellent work on implementing the original paper, I have several problems with it: - as of this commit, it (again) has bug fixes unreleased on pip - its code is badly structured - forces OOP and then proceeds @staticmethod everything - bad class design, parameters are repeated in several places - terrible contract of make_record() and generate_signature() - ambiguous parameters: path vs. image path vs. image content - doesn't adhere to PEP-8 - depends on cairo just to render svg images almost no one uses this library with --- server/requirements.txt | 1 - server/szurubooru/func/image_hash.py | 299 +++++++++++++++--- server/szurubooru/func/posts.py | 6 +- .../szurubooru/tests/func/test_image_hash.py | 6 +- 4 files changed, 267 insertions(+), 45 deletions(-) diff --git a/server/requirements.txt b/server/requirements.txt index c200222d..bef7c14d 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -7,7 +7,6 @@ pytest-cov>=2.2.1 freezegun>=0.3.6 coloredlogs==5.0 pycodestyle>=2.0.0 -image-match>=1.1.0 scipy>=0.18.1 elasticsearch>=5.0.0 elasticsearch-dsl>=5.0.0 diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index ae368488..c6fc8403 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -1,11 +1,13 @@ import logging +from io import BytesIO +from datetime import datetime import elasticsearch import elasticsearch_dsl -import xml.etree -from image_match.elasticsearch_driver import SignatureES +import numpy as np +from skimage.color import rgb2gray +from PIL import Image from szurubooru import config, errors - # pylint: disable=invalid-name logger = logging.getLogger(__name__) es = elasticsearch.Elasticsearch([{ @@ -14,11 +16,190 @@ es = elasticsearch.Elasticsearch([{ }]) -def _get_session(): - return SignatureES(es, index=config.config['elasticsearch']['index']) +# Math based on paper from H. Chi Wong, Marshall Bern and David Goldber +# Math code taken from https://github.com/ascribe/image-match +# (which is licensed under Apache 2 license) + +LOWER_PERCENTILE = 5 +UPPER_PERCENTILE = 95 +IDENTICAL_TOLERANCE = 2 / 255. +DISTANCE_CUTOFF = 0.45 +N_LEVELS = 2 +N = 9 +P = None +SAMPLE_WORDS = 16 +MAX_WORDS = 63 +ES_DOC_TYPE = 'image' +ES_MAX_RESULTS = 100 -def _safe_blanket(default_param_factory): +def _preprocess_image(image_or_path): + img = Image.open(BytesIO(image_or_path)) + img = img.convert('RGB') + return rgb2gray(np.asarray(img, dtype=np.uint8)) + + +def _crop_image(image, lower_percentile, upper_percentile): + rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1)) + cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0)) + upper_column_limit = np.searchsorted( + cw, np.percentile(cw, upper_percentile), side='left') + lower_column_limit = np.searchsorted( + cw, np.percentile(cw, lower_percentile), side='right') + upper_row_limit = np.searchsorted( + rw, np.percentile(rw, upper_percentile), side='left') + lower_row_limit = np.searchsorted( + rw, np.percentile(rw, lower_percentile), side='right') + if lower_row_limit > upper_row_limit: + lower_row_limit = int(lower_percentile / 100. * image.shape[0]) + upper_row_limit = int(upper_percentile / 100. * image.shape[0]) + if lower_column_limit > upper_column_limit: + lower_column_limit = int(lower_percentile / 100. * image.shape[1]) + upper_column_limit = int(upper_percentile / 100. * image.shape[1]) + return [ + (lower_row_limit, upper_row_limit), + (lower_column_limit, upper_column_limit)] + + +def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): + mask = np.abs(diff_array) < identical_tolerance + diff_array[mask] = 0. + if np.all(mask): + return None + positive_cutoffs = np.percentile( + diff_array[diff_array > 0.], np.linspace(0, 100, n_levels+1)) + negative_cutoffs = np.percentile( + diff_array[diff_array < 0.], np.linspace(100, 0, n_levels+1)) + for level, interval in enumerate( + positive_cutoffs[i:i+2] + for i in range(positive_cutoffs.shape[0] - 1)): + diff_array[ + (diff_array >= interval[0]) & (diff_array <= interval[1])] = \ + level + 1 + for level, interval in enumerate( + negative_cutoffs[i:i+2] + for i in range(negative_cutoffs.shape[0] - 1)): + diff_array[ + (diff_array <= interval[0]) & (diff_array >= interval[1])] = \ + -(level + 1) + return None + + +def _compute_grid_points(image, n, window=None): + if window is None: + window = [(0, image.shape[0]), (0, image.shape[1])] + x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1] + y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1] + return x_coords, y_coords + + +def _compute_mean_level(image, x_coords, y_coords, p): + if p is None: + p = max([2.0, int(0.5 + min(image.shape) / 20.)]) + avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0])) + for i, x in enumerate(x_coords): + lower_x_lim = int(max([x - p / 2, 0])) + upper_x_lim = int(min([lower_x_lim + p, image.shape[0]])) + for j, y in enumerate(y_coords): + lower_y_lim = int(max([y - p / 2, 0])) + upper_y_lim = int(min([lower_y_lim + p, image.shape[1]])) + avg_grey[i, j] = np.mean( + image[lower_x_lim:upper_x_lim, lower_y_lim:upper_y_lim]) + return avg_grey + + +def _compute_differentials(grey_level_matrix): + flipped = np.fliplr(grey_level_matrix) + right_neighbors = -np.concatenate(( + np.diff(grey_level_matrix), + np.zeros(grey_level_matrix.shape[0]) + .reshape((grey_level_matrix.shape[0], 1))), axis=1) + down_neighbors = -np.concatenate(( + np.diff(grey_level_matrix, axis=0), + np.zeros(grey_level_matrix.shape[1]) + .reshape((1, grey_level_matrix.shape[1])))) + left_neighbors = -np.concatenate( + (right_neighbors[:, -1:], right_neighbors[:, :-1]), axis=1) + up_neighbors = -np.concatenate((down_neighbors[-1:], down_neighbors[:-1])) + diagonals = np.arange( + -grey_level_matrix.shape[0] + 1, grey_level_matrix.shape[0]) + upper_left_neighbors = sum([ + np.diagflat(np.insert(np.diff(np.diag(grey_level_matrix, i)), 0, 0), i) + for i in diagonals]) + upper_right_neighbors = sum([ + np.diagflat(np.insert(np.diff(np.diag(flipped, i)), 0, 0), i) + for i in diagonals]) + lower_right_neighbors = -np.pad( + upper_left_neighbors[1:, 1:], (0, 1), mode='constant') + lower_left_neighbors = -np.pad( + upper_right_neighbors[1:, 1:], (0, 1), mode='constant') + return np.dstack(np.array([ + upper_left_neighbors, + up_neighbors, + np.fliplr(upper_right_neighbors), + left_neighbors, + right_neighbors, + np.fliplr(lower_left_neighbors), + down_neighbors, + lower_right_neighbors])) + + +def _generate_signature(path_or_image): + im_array = _preprocess_image(path_or_image) + image_limits = _crop_image(im_array, + lower_percentile=LOWER_PERCENTILE, + upper_percentile=UPPER_PERCENTILE) + x_coords, y_coords = _compute_grid_points( + im_array, n=N, window=image_limits) + avg_grey = _compute_mean_level(im_array, x_coords, y_coords, p=P) + diff_matrix = _compute_differentials(avg_grey) + _normalize_and_threshold(diff_matrix, + identical_tolerance=IDENTICAL_TOLERANCE, n_levels=N_LEVELS) + return np.ravel(diff_matrix).astype('int8') + + +def _get_words(array, k, n): + word_positions = np.linspace( + 0, array.shape[0], n, endpoint=False).astype('int') + assert k <= array.shape[0] + assert word_positions.shape[0] <= array.shape[0] + words = np.zeros((n, k)).astype('int8') + for i, pos in enumerate(word_positions): + if pos + k <= array.shape[0]: + words[i] = array[pos:pos+k] + else: + temp = array[pos:].copy() + temp.resize(k) + words[i] = temp + _max_contrast(words) + words = _words_to_int(words) + return words + + +def _words_to_int(word_array): + width = word_array.shape[1] + coding_vector = 3**np.arange(width) + return np.dot(word_array + 1, coding_vector) + + +def _max_contrast(array): + array[array > 0] = 1 + array[array < 0] = -1 + return None + + +def _normalized_distance(_target_array, _vec, nan_value=1.0): + target_array = _target_array.astype(int) + vec = _vec.astype(int) + topvec = np.linalg.norm(vec - target_array, axis=1) + norm1 = np.linalg.norm(vec, axis=0) + norm2 = np.linalg.norm(target_array, axis=1) + finvec = topvec / (norm1 + norm2) + finvec[np.isnan(finvec)] = nan_value + return finvec + + +def _safety_blanket(default_param_factory): def wrapper_outer(target_function): def wrapper_inner(*args, **kwargs): try: @@ -28,14 +209,13 @@ def _safe_blanket(default_param_factory): # add_image() return default_param_factory() except elasticsearch.exceptions.ElasticsearchException as ex: - logger.warning('Problem with elastic search: %s' % ex) + logger.warning('Problem with elastic search: %s', ex) raise errors.ThirdPartyError( 'Error connecting to elastic search.') - except xml.etree.ElementTree.ParseError as ex: - # image-match issue #60 + except IOError: raise errors.ProcessingError('Not an image.') except Exception as ex: - raise errors.ThirdPartyError('Unknown error (%s).' % ex) + raise errors.ThirdPartyError('Unknown error (%s).', ex) return wrapper_inner return wrapper_outer @@ -47,53 +227,96 @@ class Lookalike: self.path = path -@_safe_blanket(lambda: None) +@_safety_blanket(lambda: None) def add_image(path, image_content): - if not path or not image_content: - return - session = _get_session() - session.add_image(path=path, img=image_content, bytestream=True) + assert path + assert image_content + signature = _generate_signature(image_content) + words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) + + record = { + 'signature': signature.tolist(), + 'path': path, + 'timestamp': datetime.now(), + } + for i in range(MAX_WORDS): + record['simple_word_' + str(i)] = words[i].tolist() + + es.index( + index=config.config['elasticsearch']['index'], + doc_type=ES_DOC_TYPE, + body=record, + refresh=True) -@_safe_blanket(lambda: None) +@_safety_blanket(lambda: None) def delete_image(path): - if not path: - return - session = _get_session() + assert path es.delete_by_query( - index=session.index, - doc_type=session.doc_type, + index=config.config['elasticsearch']['index'], + doc_type=ES_DOC_TYPE, body={'query': {'term': {'path': path}}}) -@_safe_blanket(lambda: []) +@_safety_blanket(lambda: []) def search_by_image(image_content): + signature = _generate_signature(image_content) + words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) + + res = es.search( + index=config.config['elasticsearch']['index'], + doc_type=ES_DOC_TYPE, + body={ + 'query': + { + 'bool': + { + 'should': + [ + {'term': {'simple_word_%d' % i: word.tolist()}} + for i, word in enumerate(words) + ] + } + }, + '_source': {'excludes': ['simple_word_*']}}, + size=ES_MAX_RESULTS, + timeout='10s')['hits']['hits'] + + if len(res) == 0: + return [] + + sigs = np.array([x['_source']['signature'] for x in res]) + dists = _normalized_distance(sigs, np.array(signature)) + + ids = set() ret = [] - session = _get_session() - for result in session.search_image( - path=image_content, # sic - bytestream=True): - ret.append(Lookalike( - score=result['score'], - distance=result['dist'], - path=result['path'])) + for item, dist in zip(res, dists): + id = item['_id'] + score = item['_score'] + path = item['_source']['path'] + if id in ids: + continue + ids.add(id) + if dist < DISTANCE_CUTOFF: + ret.append(Lookalike(score=score, distance=dist, path=path)) return ret -@_safe_blanket(lambda: None) +@_safety_blanket(lambda: None) def purge(): - session = _get_session() es.delete_by_query( - index=session.index, - doc_type=session.doc_type, - body={'query': {'match_all': {}}}) + index=config.config['elasticsearch']['index'], + doc_type=ES_DOC_TYPE, + body={'query': {'match_all': {}}}, + refresh=True) -@_safe_blanket(lambda: set()) +@_safety_blanket(lambda: set()) def get_all_paths(): - session = _get_session() search = ( elasticsearch_dsl.Search( - using=es, index=session.index, doc_type=session.doc_type) + using=es, + index=config.config['elasticsearch']['index'], + doc_type=ES_DOC_TYPE) .source(['path'])) return set(h.path for h in search.scan()) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 51d11f31..7c85e8ac 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -268,7 +268,8 @@ def _after_post_update(_mapper, _connection, post): @sqlalchemy.events.event.listens_for(db.Post, 'before_delete') def _before_post_delete(_mapper, _connection, post): - image_hash.delete_image(post.post_id) + if post.post_id: + image_hash.delete_image(post.post_id) def _sync_post_content(post): @@ -279,7 +280,8 @@ def _sync_post_content(post): files.save(get_post_content_path(post), content) delattr(post, '__content') regenerate_thumb = True - if post.type in (db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + if post.post_id and post.type in ( + db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): image_hash.delete_image(post.post_id) image_hash.add_image(post.post_id, content) diff --git a/server/szurubooru/tests/func/test_image_hash.py b/server/szurubooru/tests/func/test_image_hash.py index 3bbdeba2..becba906 100644 --- a/server/szurubooru/tests/func/test_image_hash.py +++ b/server/szurubooru/tests/func/test_image_hash.py @@ -1,4 +1,3 @@ -from time import sleep from szurubooru.func import image_hash @@ -7,11 +6,10 @@ def test_hashing(read_asset, config_injector): image_hash.purge() image_hash.add_image('test', read_asset('jpeg.jpg')) - sleep(0.1) - paths = image_hash.get_all_paths() results_exact = image_hash.search_by_image(read_asset('jpeg.jpg')) - results_similar = image_hash.search_by_image(read_asset('jpeg-similar.jpg')) + results_similar = image_hash.search_by_image( + read_asset('jpeg-similar.jpg')) assert len(paths) == 1 assert len(results_exact) == 1 From abf1fc2b2d135299fe7e8a700d4e7a18966d7bfe Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 3 Feb 2017 21:42:15 +0100 Subject: [PATCH 027/159] server: make linters happier --- server/lint | 2 +- server/szurubooru/api/info_api.py | 4 +- server/szurubooru/db/post.py | 4 +- server/szurubooru/db/tag.py | 18 ++++--- server/szurubooru/db/user.py | 15 ++++-- server/szurubooru/facade.py | 2 +- server/szurubooru/func/image_hash.py | 47 ++++++++++++------- server/szurubooru/func/images.py | 3 +- server/szurubooru/func/posts.py | 27 +++++++---- server/szurubooru/func/snapshots.py | 11 +++-- server/szurubooru/func/tags.py | 40 ++++++++++------ server/szurubooru/func/users.py | 10 ++-- server/szurubooru/func/util.py | 5 +- ...b3_add_default_column_to_tag_categories.py | 4 +- .../versions/7f6baf38c27c_add_versions.py | 4 +- server/szurubooru/rest/app.py | 2 +- server/szurubooru/rest/context.py | 9 ++-- .../search/configs/post_search_config.py | 2 +- server/szurubooru/search/configs/util.py | 3 +- server/szurubooru/search/executor.py | 20 ++++---- .../tests/api/test_post_retrieving.py | 5 +- .../tests/api/test_tag_siblings_retrieving.py | 5 +- server/szurubooru/tests/func/test_posts.py | 5 +- server/szurubooru/tests/func/test_users.py | 12 ++--- 24 files changed, 153 insertions(+), 106 deletions(-) diff --git a/server/lint b/server/lint index b53f74f4..218d3bb4 100755 --- a/server/lint +++ b/server/lint @@ -1,3 +1,3 @@ #!/bin/sh pylint szurubooru -pycodestyle szurubooru --ignore=E128,E131,W503 +pycodestyle szurubooru diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index d815485a..c0d2a955 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -33,11 +33,11 @@ def get_info(ctx, _params=None): 'diskUsage': _get_disk_usage(), 'featuredPost': posts.serialize_post(post_feature.post, ctx.user) - if post_feature else None, + if post_feature else None, 'featuringTime': post_feature.time if post_feature else None, 'featuringUser': users.serialize_user(post_feature.user, ctx.user) - if post_feature else None, + if post_feature else None, 'serverTime': datetime.datetime.utcnow(), 'config': { 'userNameRegex': config.config['user_name_regex'], diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py index 8be7e8b5..f0c9f91f 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/db/post.py @@ -252,8 +252,8 @@ class Post(Base): relation_count = column_property( select([func.count(PostRelation.child_id)]) .where( - (PostRelation.parent_id == post_id) - | (PostRelation.child_id == post_id)) + (PostRelation.parent_id == post_id) | + (PostRelation.child_id == post_id)) .correlate_except(PostRelation)) __mapper_args__ = { diff --git a/server/szurubooru/db/tag.py b/server/szurubooru/db/tag.py index e0861312..10813eb9 100644 --- a/server/szurubooru/db/tag.py +++ b/server/szurubooru/db/tag.py @@ -106,23 +106,29 @@ class Tag(Base): .correlate_except(PostTag)) first_name = column_property( - select([TagName.name]) + ( + select([TagName.name]) .where(TagName.tag_id == tag_id) .order_by(TagName.order) .limit(1) - .as_scalar(), + .as_scalar() + ), deferred=True) suggestion_count = column_property( - select([func.count(TagSuggestion.child_id)]) + ( + select([func.count(TagSuggestion.child_id)]) .where(TagSuggestion.parent_id == tag_id) - .as_scalar(), + .as_scalar() + ), deferred=True) implication_count = column_property( - select([func.count(TagImplication.child_id)]) + ( + select([func.count(TagImplication.child_id)]) .where(TagImplication.parent_id == tag_id) - .as_scalar(), + .as_scalar() + ), deferred=True) __mapper_args__ = { diff --git a/server/szurubooru/db/user.py b/server/szurubooru/db/user.py index 082adcff..4f4f9961 100644 --- a/server/szurubooru/db/user.py +++ b/server/szurubooru/db/user.py @@ -37,7 +37,8 @@ class User(Base): @property def post_count(self): from szurubooru.db import session - return (session + return ( + session .query(func.sum(1)) .filter(Post.user_id == self.user_id) .one()[0] or 0) @@ -45,7 +46,8 @@ class User(Base): @property def comment_count(self): from szurubooru.db import session - return (session + return ( + session .query(func.sum(1)) .filter(Comment.user_id == self.user_id) .one()[0] or 0) @@ -53,7 +55,8 @@ class User(Base): @property def favorite_post_count(self): from szurubooru.db import session - return (session + return ( + session .query(func.sum(1)) .filter(PostFavorite.user_id == self.user_id) .one()[0] or 0) @@ -61,7 +64,8 @@ class User(Base): @property def liked_post_count(self): from szurubooru.db import session - return (session + return ( + session .query(func.sum(1)) .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == 1) @@ -70,7 +74,8 @@ class User(Base): @property def disliked_post_count(self): from szurubooru.db import session - return (session + return ( + session .query(func.sum(1)) .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == -1) diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index c85ac1ae..48957a1f 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -127,4 +127,4 @@ def create_app(): return rest.application -app = create_app() +app = create_app() # pylint: disable=invalid-name diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index c6fc8403..c89a2ec1 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -67,21 +67,21 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): if np.all(mask): return None positive_cutoffs = np.percentile( - diff_array[diff_array > 0.], np.linspace(0, 100, n_levels+1)) + diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1)) negative_cutoffs = np.percentile( - diff_array[diff_array < 0.], np.linspace(100, 0, n_levels+1)) + diff_array[diff_array < 0.], np.linspace(100, 0, n_levels + 1)) for level, interval in enumerate( - positive_cutoffs[i:i+2] + positive_cutoffs[i:i + 2] for i in range(positive_cutoffs.shape[0] - 1)): diff_array[ (diff_array >= interval[0]) & (diff_array <= interval[1])] = \ - level + 1 + level + 1 for level, interval in enumerate( - negative_cutoffs[i:i+2] + negative_cutoffs[i:i + 2] for i in range(negative_cutoffs.shape[0] - 1)): diff_array[ (diff_array <= interval[0]) & (diff_array >= interval[1])] = \ - -(level + 1) + -(level + 1) return None @@ -110,14 +110,22 @@ def _compute_mean_level(image, x_coords, y_coords, p): def _compute_differentials(grey_level_matrix): flipped = np.fliplr(grey_level_matrix) - right_neighbors = -np.concatenate(( - np.diff(grey_level_matrix), - np.zeros(grey_level_matrix.shape[0]) - .reshape((grey_level_matrix.shape[0], 1))), axis=1) - down_neighbors = -np.concatenate(( - np.diff(grey_level_matrix, axis=0), - np.zeros(grey_level_matrix.shape[1]) - .reshape((1, grey_level_matrix.shape[1])))) + right_neighbors = -np.concatenate( + ( + np.diff(grey_level_matrix), + ( + np.zeros(grey_level_matrix.shape[0]) + .reshape((grey_level_matrix.shape[0], 1)) + ) + ), axis=1) + down_neighbors = -np.concatenate( + ( + np.diff(grey_level_matrix, axis=0), + ( + np.zeros(grey_level_matrix.shape[1]) + .reshape((1, grey_level_matrix.shape[1])) + ) + )) left_neighbors = -np.concatenate( (right_neighbors[:, -1:], right_neighbors[:, :-1]), axis=1) up_neighbors = -np.concatenate((down_neighbors[-1:], down_neighbors[:-1])) @@ -146,15 +154,18 @@ def _compute_differentials(grey_level_matrix): def _generate_signature(path_or_image): im_array = _preprocess_image(path_or_image) - image_limits = _crop_image(im_array, + image_limits = _crop_image( + im_array, lower_percentile=LOWER_PERCENTILE, upper_percentile=UPPER_PERCENTILE) x_coords, y_coords = _compute_grid_points( im_array, n=N, window=image_limits) avg_grey = _compute_mean_level(im_array, x_coords, y_coords, p=P) diff_matrix = _compute_differentials(avg_grey) - _normalize_and_threshold(diff_matrix, - identical_tolerance=IDENTICAL_TOLERANCE, n_levels=N_LEVELS) + _normalize_and_threshold( + diff_matrix, + identical_tolerance=IDENTICAL_TOLERANCE, + n_levels=N_LEVELS) return np.ravel(diff_matrix).astype('int8') @@ -166,7 +177,7 @@ def _get_words(array, k, n): words = np.zeros((n, k)).astype('int8') for i, pos in enumerate(word_positions): if pos + k <= array.shape[0]: - words[i] = array[pos:pos+k] + words[i] = array[pos:pos + k] else: temp = array[pos:].copy() temp.resize(k) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index a04d20dd..fdab793b 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -111,4 +111,5 @@ class Image: assert 'streams' in self.info if len(self.info['streams']) < 1: logger.warning('The video contains no video streams.') - raise errors.ProcessingError('The video contains no video streams.') + raise errors.ProcessingError( + 'The video contains no video streams.') diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 7c85e8ac..c942e799 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -474,12 +474,15 @@ def merge_posts(source_post, target_post, replace_content): def merge_tables(table, anti_dup_func, source_post_id, target_post_id): alias1 = table alias2 = sqlalchemy.orm.util.aliased(table) - update_stmt = (sqlalchemy.sql.expression.update(alias1) + update_stmt = ( + sqlalchemy.sql.expression.update(alias1) .where(alias1.post_id == source_post_id)) if anti_dup_func is not None: - update_stmt = (update_stmt - .where(~sqlalchemy.exists() + update_stmt = ( + update_stmt + .where( + ~sqlalchemy.exists() .where(anti_dup_func(alias1, alias2)) .where(alias2.post_id == target_post_id))) @@ -513,19 +516,23 @@ def merge_posts(source_post, target_post, replace_content): def merge_relations(source_post_id, target_post_id): alias1 = db.PostRelation alias2 = sqlalchemy.orm.util.aliased(db.PostRelation) - update_stmt = (sqlalchemy.sql.expression.update(alias1) + update_stmt = ( + sqlalchemy.sql.expression.update(alias1) .where(alias1.parent_id == source_post_id) .where(alias1.child_id != target_post_id) - .where(~sqlalchemy.exists() + .where( + ~sqlalchemy.exists() .where(alias2.child_id == alias1.child_id) .where(alias2.parent_id == target_post_id)) .values(parent_id=target_post_id)) db.session.execute(update_stmt) - update_stmt = (sqlalchemy.sql.expression.update(alias1) + update_stmt = ( + sqlalchemy.sql.expression.update(alias1) .where(alias1.child_id == source_post_id) .where(alias1.parent_id != target_post_id) - .where(~sqlalchemy.exists() + .where( + ~sqlalchemy.exists() .where(alias2.parent_id == alias1.parent_id) .where(alias2.child_id == target_post_id)) .values(child_id=target_post_id)) @@ -567,7 +574,8 @@ def search_by_image(image_content): def populate_reverse_search(): excluded_post_ids = image_hash.get_all_paths() - post_ids_to_hash = (db.session + post_ids_to_hash = ( + db.session .query(db.Post.post_id) .filter( (db.Post.type == db.Post.TYPE_IMAGE) | @@ -577,7 +585,8 @@ def populate_reverse_search(): .all()) for post_ids_chunk in util.chunks(post_ids_to_hash, 100): - posts_chunk = (db.session + posts_chunk = ( + db.session .query(db.Post) .filter(db.Post.post_id.in_(post_ids_chunk)) .all()) diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index 4c236bb5..f7efda9e 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -86,10 +86,13 @@ def create(entity, auth_user): def modify(entity, auth_user): assert entity - model = next((model - for model in db.Base._decl_class_registry.values() - if hasattr(model, '__table__') - and model.__table__.fullname == entity.__table__.fullname), + model = next( + ( + model + for model in db.Base._decl_class_registry.values() + if hasattr(model, '__table__') + and model.__table__.fullname == entity.__table__.fullname + ), None) assert model diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index ece1231d..1665282b 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -104,7 +104,8 @@ def export_to_json(): 'color': result[2], } - for result in (db.session + for result in ( + db.session .query(db.TagName.tag_id, db.TagName.name) .order_by(db.TagName.order) .all()): @@ -112,7 +113,8 @@ def export_to_json(): tags[result[0]] = {'names': []} tags[result[0]]['names'].append(result[1]) - for result in (db.session + for result in ( + db.session .query(db.TagSuggestion.parent_id, db.TagName.name) .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) .all()): @@ -120,7 +122,8 @@ def export_to_json(): tags[result[0]]['suggestions'] = [] tags[result[0]]['suggestions'].append(result[1]) - for result in (db.session + for result in ( + db.session .query(db.TagImplication.parent_id, db.TagName.name) .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) .all()): @@ -146,7 +149,8 @@ def export_to_json(): def try_get_tag_by_name(name): - return (db.session + return ( + db.session .query(db.Tag) .join(db.TagName) .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) @@ -198,7 +202,8 @@ def get_tag_siblings(tag): tag_alias = sqlalchemy.orm.aliased(db.Tag) pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) - result = (db.session + result = ( + db.session .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) @@ -214,10 +219,10 @@ def delete(source_tag): assert source_tag db.session.execute( sqlalchemy.sql.expression.delete(db.TagSuggestion) - .where(db.TagSuggestion.child_id == source_tag.tag_id)) + .where(db.TagSuggestion.child_id == source_tag.tag_id)) db.session.execute( sqlalchemy.sql.expression.delete(db.TagImplication) - .where(db.TagImplication.child_id == source_tag.tag_id)) + .where(db.TagImplication.child_id == source_tag.tag_id)) db.session.delete(source_tag) @@ -230,10 +235,13 @@ def merge_tags(source_tag, target_tag): def merge_posts(source_tag_id, target_tag_id): alias1 = db.PostTag alias2 = sqlalchemy.orm.util.aliased(db.PostTag) - update_stmt = (sqlalchemy.sql.expression.update(alias1) + update_stmt = ( + sqlalchemy.sql.expression.update(alias1) .where(alias1.tag_id == source_tag_id)) - update_stmt = (update_stmt - .where(~sqlalchemy.exists() + update_stmt = ( + update_stmt + .where( + ~sqlalchemy.exists() .where(alias1.post_id == alias2.post_id) .where(alias2.tag_id == target_tag_id))) update_stmt = update_stmt.values(tag_id=target_tag_id) @@ -242,19 +250,23 @@ def merge_tags(source_tag, target_tag): def merge_relations(table, source_tag_id, target_tag_id): alias1 = table alias2 = sqlalchemy.orm.util.aliased(table) - update_stmt = (sqlalchemy.sql.expression.update(alias1) + update_stmt = ( + sqlalchemy.sql.expression.update(alias1) .where(alias1.parent_id == source_tag_id) .where(alias1.child_id != target_tag_id) - .where(~sqlalchemy.exists() + .where( + ~sqlalchemy.exists() .where(alias2.child_id == alias1.child_id) .where(alias2.parent_id == target_tag_id)) .values(parent_id=target_tag_id)) db.session.execute(update_stmt) - update_stmt = (sqlalchemy.sql.expression.update(alias1) + update_stmt = ( + sqlalchemy.sql.expression.update(alias1) .where(alias1.child_id == source_tag_id) .where(alias1.parent_id != target_tag_id) - .where(~sqlalchemy.exists() + .where( + ~sqlalchemy.exists() .where(alias2.parent_id == alias1.parent_id) .where(alias2.child_id == target_tag_id)) .values(child_id=target_tag_id)) diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index c644d4a2..5547bbae 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -44,10 +44,9 @@ def get_avatar_url(user): return 'https://gravatar.com/avatar/%s?d=retro&s=%d' % ( util.get_md5((user.email or user.name).lower()), config.config['thumbnails']['avatar_width']) - else: - assert user.name - return '%s/avatars/%s.png' % ( - config.config['data_url'].rstrip('/'), user.name.lower()) + assert user.name + return '%s/avatars/%s.png' % ( + config.config['data_url'].rstrip('/'), user.name.lower()) def get_email(user, auth_user, force_show_email): @@ -126,7 +125,8 @@ def get_user_by_name(name): def try_get_user_by_name_or_email(name_or_email): - return (db.session + return ( + db.session .query(db.User) .filter( (func.lower(db.User.name) == func.lower(name_or_email)) | diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 497878ee..11caedd2 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -102,6 +102,7 @@ def parse_time_range(value): ''' Return tuple containing min/max time for given text representation. ''' one_day = timedelta(days=1) one_second = timedelta(seconds=1) + almost_one_day = one_day - one_second value = value.lower() if not value: @@ -111,8 +112,8 @@ def parse_time_range(value): now = datetime.utcnow() return ( datetime(now.year, now.month, now.day, 0, 0, 0), - datetime(now.year, now.month, now.day, 0, 0, 0) - + one_day - one_second) + datetime(now.year, now.month, now.day, 0, 0, 0) + almost_one_day + ) if value == 'yesterday': now = datetime.utcnow() diff --git a/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py b/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py index e6a37e6d..1ced1596 100644 --- a/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py +++ b/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py @@ -19,8 +19,8 @@ def upgrade(): 'tag_category', sa.Column('default', sa.Boolean(), nullable=True)) op.execute( sa.table('tag_category', sa.column('default')) - .update() - .values(default=False)) + .update() + .values(default=False)) op.alter_column('tag_category', 'default', nullable=False) diff --git a/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py b/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py index 71e6bcb0..22360260 100644 --- a/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py +++ b/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py @@ -21,8 +21,8 @@ def upgrade(): op.add_column(table, sa.Column('version', sa.Integer(), nullable=True)) op.execute( sa.table(table, sa.column('version')) - .update() - .values(version=1)) + .update() + .values(version=1)) op.alter_column(table, 'version', nullable=False) diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index 391d297f..1bbf8dce 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -93,7 +93,7 @@ def application(env, start_response): hook(ctx) try: response = handler(ctx, match.groupdict()) - except: + except Exception: ctx.session.rollback() raise finally: diff --git a/server/szurubooru/rest/context.py b/server/szurubooru/rest/context.py index 081064ed..ae26f38b 100644 --- a/server/szurubooru/rest/context.py +++ b/server/szurubooru/rest/context.py @@ -44,9 +44,10 @@ class Context: return self._headers.get(name, None) def has_file(self, name, allow_tokens=True): - return (name in self._files - or name + 'Url' in self._params - or (allow_tokens and name + 'Token' in self._params)) + return ( + name in self._files or + name + 'Url' in self._params or + (allow_tokens and name + 'Token' in self._params)) def get_file(self, name, required=False, allow_tokens=True): ret = None @@ -80,7 +81,7 @@ class Context: if isinstance(value, list): try: value = ','.join(value) - except: + except TypeError: raise errors.InvalidParameterError('Expected simple string.') return value diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index cf02fecc..7005cd7c 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -46,7 +46,7 @@ def _create_score_filter(score): if not getattr(criterion, 'internal', False): raise errors.SearchError( 'Votes cannot be seen publicly. Did you mean %r?' - % 'special:liked') + % 'special:liked') user_alias = aliased(db.User) score_alias = aliased(db.PostScore) expr = score_alias.score == score diff --git a/server/szurubooru/search/configs/util.py b/server/szurubooru/search/configs/util.py index 7cb36e9f..2eaaf8d7 100644 --- a/server/szurubooru/search/configs/util.py +++ b/server/szurubooru/search/configs/util.py @@ -5,7 +5,8 @@ from szurubooru.search import criteria def wildcard_transformer(value): - return (value + return ( + value .replace('\\', '\\\\') .replace('%', '\\%') .replace('_', '\\_') diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index e1ee53ad..d9adc940 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -41,18 +41,18 @@ class Executor: filter_query, search_query, False) prev_filter_query = ( filter_query - .filter(self.config.id_column > entity_id) - .order_by(None) - .order_by(sqlalchemy.func.abs( - self.config.id_column - entity_id).asc()) - .limit(1)) + .filter(self.config.id_column > entity_id) + .order_by(None) + .order_by(sqlalchemy.func.abs( + self.config.id_column - entity_id).asc()) + .limit(1)) next_filter_query = ( filter_query - .filter(self.config.id_column < entity_id) - .order_by(None) - .order_by(sqlalchemy.func.abs( - self.config.id_column - entity_id).asc()) - .limit(1)) + .filter(self.config.id_column < entity_id) + .order_by(None) + .order_by(sqlalchemy.func.abs( + self.config.id_column - entity_id).asc()) + .limit(1)) return [ prev_filter_query.one_or_none(), next_filter_query.one_or_none()] diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index a33d60a3..a02c7bc1 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -44,9 +44,8 @@ def test_using_special_tokens(user_factory, post_factory, context_factory): db.session.add_all([post1, post2, auth_user]) db.session.flush() with patch('szurubooru.func.posts.serialize_post'): - posts.serialize_post.side_effect = \ - lambda post, *_args, **_kwargs: \ - 'serialized post %d' % post.post_id + posts.serialize_post.side_effect = lambda post, *_args, **_kwargs: \ + 'serialized post %d' % post.post_id result = api.post_api.get_posts( context_factory( params={'query': 'special:fav', 'page': 1}, diff --git a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py index 6ba00868..6de25fcc 100644 --- a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py @@ -14,9 +14,8 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): db.session.flush() with patch('szurubooru.func.tags.serialize_tag'), \ patch('szurubooru.func.tags.get_tag_siblings'): - tags.serialize_tag.side_effect = \ - lambda tag, *args, **kwargs: \ - 'serialized tag %s' % tag.names[0].name + tags.serialize_tag.side_effect = lambda tag, *args, **kwargs: \ + 'serialized tag %s' % tag.names[0].name tags.get_tag_siblings.return_value = [ (tag_factory(names=['sib1']), 1), (tag_factory(names=['sib2']), 3), diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index a10cc6d9..682a1ccc 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -494,9 +494,8 @@ def test_update_post_content_leaving_custom_thumbnail( def test_update_post_tags(tag_factory): post = db.Post() with patch('szurubooru.func.tags.get_or_create_tags_by_names'): - tags.get_or_create_tags_by_names.side_effect \ - = lambda tag_names: \ - ([tag_factory(names=[name]) for name in tag_names], []) + tags.get_or_create_tags_by_names.side_effect = lambda tag_names: \ + ([tag_factory(names=[name]) for name in tag_names], []) posts.update_post_tags(post, ['tag1', 'tag2']) assert len(post.tags) == 2 assert post.tags[0].names[0].name == 'tag1' diff --git a/server/szurubooru/tests/func/test_users.py b/server/szurubooru/tests/func/test_users.py index ce2a40ec..73150bb2 100644 --- a/server/szurubooru/tests/func/test_users.py +++ b/server/szurubooru/tests/func/test_users.py @@ -21,22 +21,22 @@ def test_get_avatar_path(user_name): 'user', None, db.User.AVATAR_GRAVATAR, - 'https://gravatar.com/avatar/' + - 'ee11cbb19052e40b07aac0ca060c23ee?d=retro&s=100', + ('https://gravatar.com/avatar/' + + 'ee11cbb19052e40b07aac0ca060c23ee?d=retro&s=100'), ), ( None, 'user@example.com', db.User.AVATAR_GRAVATAR, - 'https://gravatar.com/avatar/' + - 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100', + ('https://gravatar.com/avatar/' + + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100'), ), ( 'user', 'user@example.com', db.User.AVATAR_GRAVATAR, - 'https://gravatar.com/avatar/' + - 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100', + ('https://gravatar.com/avatar/' + + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100'), ), ( 'user', From ad842ee8a54c57463b8e28b52970f173ea1d64ea Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 4 Feb 2017 01:08:12 +0100 Subject: [PATCH 028/159] server: refactor + add type hinting - Added type hinting (for now, 3.5-compatible) - Split `db` namespace into `db` module and `model` namespace - Changed elastic search to be created lazily for each operation - Changed to class based approach in entity serialization to allow stronger typing - Removed `required` argument from `context.get_*` family of functions; now it's implied if `default` argument is omitted - Changed `unalias_dict` implementation to use less magic inputs --- server/migrate-v1 | 6 +- server/mypy.ini | 14 + server/szurubooru/api/comment_api.py | 75 +-- server/szurubooru/api/info_api.py | 24 +- server/szurubooru/api/password_reset_api.py | 18 +- server/szurubooru/api/post_api.py | 145 ++--- server/szurubooru/api/snapshot_api.py | 9 +- server/szurubooru/api/tag_api.py | 77 +-- server/szurubooru/api/tag_category_api.py | 44 +- server/szurubooru/api/upload_api.py | 10 +- server/szurubooru/api/user_api.py | 45 +- server/szurubooru/config.py | 5 +- server/szurubooru/db.py | 36 ++ server/szurubooru/db/__init__.py | 17 - server/szurubooru/db/session.py | 27 - server/szurubooru/db/util.py | 34 -- server/szurubooru/errors.py | 8 +- server/szurubooru/facade.py | 32 +- server/szurubooru/func/auth.py | 30 +- server/szurubooru/func/cache.py | 26 +- server/szurubooru/func/comments.py | 98 ++-- server/szurubooru/func/diff.py | 17 +- server/szurubooru/func/favorites.py | 24 +- server/szurubooru/func/file_uploads.py | 17 +- server/szurubooru/func/files.py | 19 +- server/szurubooru/func/image_hash.py | 106 ++-- server/szurubooru/func/images.py | 19 +- server/szurubooru/func/mailer.py | 2 +- server/szurubooru/func/mime.py | 13 +- server/szurubooru/func/net.py | 2 +- server/szurubooru/func/posts.py | 452 ++++++++++------ server/szurubooru/func/scores.py | 22 +- server/szurubooru/func/serialization.py | 27 + server/szurubooru/func/snapshots.py | 59 +- server/szurubooru/func/tag_categories.py | 109 ++-- server/szurubooru/func/tags.py | 239 +++++---- server/szurubooru/func/users.py | 194 ++++--- server/szurubooru/func/util.py | 74 ++- server/szurubooru/func/versions.py | 11 +- server/szurubooru/middleware/authenticator.py | 27 +- server/szurubooru/middleware/cache_purger.py | 3 +- .../szurubooru/middleware/request_logger.py | 6 +- server/szurubooru/migrations/env.py | 4 +- server/szurubooru/model/__init__.py | 15 + server/szurubooru/{db => model}/base.py | 0 server/szurubooru/{db => model}/comment.py | 15 +- server/szurubooru/{db => model}/post.py | 15 +- server/szurubooru/{db => model}/snapshot.py | 2 +- server/szurubooru/{db => model}/tag.py | 10 +- .../szurubooru/{db => model}/tag_category.py | 7 +- server/szurubooru/{db => model}/user.py | 55 +- server/szurubooru/model/util.py | 42 ++ server/szurubooru/rest/__init__.py | 2 +- server/szurubooru/rest/app.py | 19 +- server/szurubooru/rest/context.py | 187 ++++--- server/szurubooru/rest/errors.py | 18 +- server/szurubooru/rest/middleware.py | 12 +- server/szurubooru/rest/routes.py | 22 +- .../search/configs/base_search_config.py | 29 +- .../search/configs/comment_search_config.py | 71 +-- .../search/configs/post_search_config.py | 504 ++++++++++++------ .../search/configs/snapshot_search_config.py | 41 +- .../search/configs/tag_search_config.py | 177 +++--- .../search/configs/user_search_config.py | 62 ++- server/szurubooru/search/configs/util.py | 91 ++-- server/szurubooru/search/criteria.py | 32 +- server/szurubooru/search/executor.py | 114 ++-- server/szurubooru/search/parser.py | 30 +- server/szurubooru/search/query.py | 16 + server/szurubooru/search/tokens.py | 21 +- server/szurubooru/search/typing.py | 6 + .../tests/api/test_comment_creating.py | 17 +- .../tests/api/test_comment_deleting.py | 22 +- .../tests/api/test_comment_rating.py | 39 +- .../tests/api/test_comment_retrieving.py | 16 +- .../tests/api/test_comment_updating.py | 18 +- .../tests/api/test_password_reset.py | 12 +- .../tests/api/test_post_creating.py | 50 +- .../tests/api/test_post_deleting.py | 14 +- .../tests/api/test_post_favoriting.py | 27 +- .../tests/api/test_post_featuring.py | 26 +- .../szurubooru/tests/api/test_post_merging.py | 17 +- .../szurubooru/tests/api/test_post_rating.py | 26 +- .../tests/api/test_post_retrieving.py | 22 +- .../tests/api/test_post_updating.py | 34 +- .../tests/api/test_snapshot_retrieving.py | 10 +- .../tests/api/test_tag_category_creating.py | 10 +- .../tests/api/test_tag_category_deleting.py | 22 +- .../tests/api/test_tag_category_retrieving.py | 14 +- .../tests/api/test_tag_category_updating.py | 18 +- .../szurubooru/tests/api/test_tag_creating.py | 12 +- .../szurubooru/tests/api/test_tag_deleting.py | 18 +- .../szurubooru/tests/api/test_tag_merging.py | 14 +- .../tests/api/test_tag_retrieving.py | 16 +- .../tests/api/test_tag_siblings_retrieving.py | 10 +- .../szurubooru/tests/api/test_tag_updating.py | 35 +- .../tests/api/test_user_creating.py | 10 +- .../tests/api/test_user_deleting.py | 24 +- .../tests/api/test_user_retrieving.py | 26 +- .../tests/api/test_user_updating.py | 36 +- server/szurubooru/tests/conftest.py | 46 +- server/szurubooru/tests/func/test_comments.py | 4 - .../szurubooru/tests/func/test_image_hash.py | 8 +- server/szurubooru/tests/func/test_posts.py | 80 ++- .../szurubooru/tests/func/test_snapshots.py | 32 +- .../tests/func/test_tag_categories.py | 4 +- server/szurubooru/tests/func/test_tags.py | 6 +- server/szurubooru/tests/func/test_users.py | 34 +- .../tests/{db => model}/__init__.py | 0 .../tests/{db => model}/test_comment.py | 18 +- .../tests/{db => model}/test_post.py | 52 +- .../tests/{db => model}/test_tag.py | 28 +- .../tests/{db => model}/test_user.py | 72 +-- server/szurubooru/tests/rest/test_context.py | 36 +- .../search/configs/test_post_search_config.py | 28 +- server/test | 1 + 116 files changed, 2868 insertions(+), 2037 deletions(-) create mode 100644 server/mypy.ini create mode 100644 server/szurubooru/db.py delete mode 100644 server/szurubooru/db/__init__.py delete mode 100644 server/szurubooru/db/session.py delete mode 100644 server/szurubooru/db/util.py create mode 100644 server/szurubooru/func/serialization.py create mode 100644 server/szurubooru/model/__init__.py rename server/szurubooru/{db => model}/base.py (100%) rename server/szurubooru/{db => model}/comment.py (84%) rename server/szurubooru/{db => model}/post.py (95%) rename server/szurubooru/{db => model}/snapshot.py (96%) rename server/szurubooru/{db => model}/tag.py (93%) rename server/szurubooru/{db => model}/tag_category.py (84%) rename server/szurubooru/{db => model}/user.py (50%) create mode 100644 server/szurubooru/model/util.py create mode 100644 server/szurubooru/search/query.py create mode 100644 server/szurubooru/search/typing.py rename server/szurubooru/tests/{db => model}/__init__.py (100%) rename server/szurubooru/tests/{db => model}/test_comment.py (74%) rename server/szurubooru/tests/{db => model}/test_post.py (71%) rename server/szurubooru/tests/{db => model}/test_tag.py (80%) rename server/szurubooru/tests/{db => model}/test_user.py (66%) diff --git a/server/migrate-v1 b/server/migrate-v1 index 0fdf9e4f..d3ec0dda 100755 --- a/server/migrate-v1 +++ b/server/migrate-v1 @@ -8,7 +8,7 @@ import zlib import concurrent.futures import logging import coloredlogs -import sqlalchemy +import sqlalchemy as sa from szurubooru import config, db from szurubooru.func import files, images, posts, comments @@ -42,8 +42,8 @@ def get_v1_session(args): port=args.port, name=args.name) logger.info('Connecting to %r...', dsn) - engine = sqlalchemy.create_engine(dsn) - session_maker = sqlalchemy.orm.sessionmaker(bind=engine) + engine = sa.create_engine(dsn) + session_maker = sa.orm.sessionmaker(bind=engine) return session_maker() def parse_args(): diff --git a/server/mypy.ini b/server/mypy.ini new file mode 100644 index 00000000..a0300b7a --- /dev/null +++ b/server/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +ignore_missing_imports = True +follow_imports = skip +disallow_untyped_calls = True +disallow_untyped_defs = True +check_untyped_defs = True +disallow_subclassing_any = False +warn_redundant_casts = True +warn_unused_ignores = True +strict_optional = True +strict_boolean = False + +[mypy-szurubooru.tests.*] +ignore_errors=True diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index 51058160..1cde5385 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -1,31 +1,44 @@ -import datetime -from szurubooru import search -from szurubooru.rest import routes -from szurubooru.func import auth, comments, posts, scores, util, versions +from typing import Dict +from datetime import datetime +from szurubooru import search, rest, model +from szurubooru.func import ( + auth, comments, posts, scores, versions, serialization) _search_executor = search.Executor(search.configs.CommentSearchConfig()) -def _serialize(ctx, comment, **kwargs): +def _get_comment(params: Dict[str, str]) -> model.Comment: + try: + comment_id = int(params['comment_id']) + except TypeError: + raise comments.InvalidCommentIdError( + 'Invalid comment ID: %r.' % params['comment_id']) + return comments.get_comment_by_id(comment_id) + + +def _serialize( + ctx: rest.Context, comment: model.Comment) -> rest.Response: return comments.serialize_comment( comment, ctx.user, - options=util.get_serialization_options(ctx), **kwargs) + options=serialization.get_serialization_options(ctx)) -@routes.get('/comments/?') -def get_comments(ctx, _params=None): +@rest.routes.get('/comments/?') +def get_comments( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:list') return _search_executor.execute_and_serialize( ctx, lambda comment: _serialize(ctx, comment)) -@routes.post('/comments/?') -def create_comment(ctx, _params=None): +@rest.routes.post('/comments/?') +def create_comment( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:create') - text = ctx.get_param_as_string('text', required=True) - post_id = ctx.get_param_as_int('postId', required=True) + text = ctx.get_param_as_string('text') + post_id = ctx.get_param_as_int('postId') post = posts.get_post_by_id(post_id) comment = comments.create_comment(ctx.user, post, text) ctx.session.add(comment) @@ -33,30 +46,30 @@ def create_comment(ctx, _params=None): return _serialize(ctx, comment) -@routes.get('/comment/(?P[^/]+)/?') -def get_comment(ctx, params): +@rest.routes.get('/comment/(?P[^/]+)/?') +def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:view') - comment = comments.get_comment_by_id(params['comment_id']) + comment = _get_comment(params) return _serialize(ctx, comment) -@routes.put('/comment/(?P[^/]+)/?') -def update_comment(ctx, params): - comment = comments.get_comment_by_id(params['comment_id']) +@rest.routes.put('/comment/(?P[^/]+)/?') +def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + comment = _get_comment(params) versions.verify_version(comment, ctx) versions.bump_version(comment) infix = 'own' if ctx.user.user_id == comment.user_id else 'any' - text = ctx.get_param_as_string('text', required=True) + text = ctx.get_param_as_string('text') auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) comments.update_comment_text(comment, text) - comment.last_edit_time = datetime.datetime.utcnow() + comment.last_edit_time = datetime.utcnow() ctx.session.commit() return _serialize(ctx, comment) -@routes.delete('/comment/(?P[^/]+)/?') -def delete_comment(ctx, params): - comment = comments.get_comment_by_id(params['comment_id']) +@rest.routes.delete('/comment/(?P[^/]+)/?') +def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + comment = _get_comment(params) versions.verify_version(comment, ctx) infix = 'own' if ctx.user.user_id == comment.user_id else 'any' auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) @@ -65,20 +78,22 @@ def delete_comment(ctx, params): return {} -@routes.put('/comment/(?P[^/]+)/score/?') -def set_comment_score(ctx, params): +@rest.routes.put('/comment/(?P[^/]+)/score/?') +def set_comment_score( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:score') - score = ctx.get_param_as_int('score', required=True) - comment = comments.get_comment_by_id(params['comment_id']) + score = ctx.get_param_as_int('score') + comment = _get_comment(params) scores.set_score(comment, ctx.user, score) ctx.session.commit() return _serialize(ctx, comment) -@routes.delete('/comment/(?P[^/]+)/score/?') -def delete_comment_score(ctx, params): +@rest.routes.delete('/comment/(?P[^/]+)/score/?') +def delete_comment_score( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:score') - comment = comments.get_comment_by_id(params['comment_id']) + comment = _get_comment(params) scores.delete_score(comment, ctx.user) ctx.session.commit() return _serialize(ctx, comment) diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index c0d2a955..e0fafedd 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -1,19 +1,20 @@ -import datetime import os -from szurubooru import config -from szurubooru.rest import routes +from typing import Optional, Dict +from datetime import datetime, timedelta +from szurubooru import config, rest from szurubooru.func import posts, users, util -_cache_time = None -_cache_result = None +_cache_time = None # type: Optional[datetime] +_cache_result = None # type: Optional[int] -def _get_disk_usage(): +def _get_disk_usage() -> int: global _cache_time, _cache_result # pylint: disable=global-statement - threshold = datetime.timedelta(hours=48) - now = datetime.datetime.utcnow() + threshold = timedelta(hours=48) + now = datetime.utcnow() if _cache_time and _cache_time > now - threshold: + assert _cache_result return _cache_result total_size = 0 for dir_path, _, file_names in os.walk(config.config['data_dir']): @@ -25,8 +26,9 @@ def _get_disk_usage(): return total_size -@routes.get('/info/?') -def get_info(ctx, _params=None): +@rest.routes.get('/info/?') +def get_info( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: post_feature = posts.try_get_current_post_feature() return { 'postCount': posts.get_post_count(), @@ -38,7 +40,7 @@ def get_info(ctx, _params=None): 'featuringUser': users.serialize_user(post_feature.user, ctx.user) if post_feature else None, - 'serverTime': datetime.datetime.utcnow(), + 'serverTime': datetime.utcnow(), 'config': { 'userNameRegex': config.config['user_name_regex'], 'passwordRegex': config.config['password_regex'], diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 7e5864c9..f49080a9 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -1,5 +1,5 @@ -from szurubooru import config, errors -from szurubooru.rest import routes +from typing import Dict +from szurubooru import config, errors, rest from szurubooru.func import auth, mailer, users, versions @@ -10,9 +10,9 @@ MAIL_BODY = \ 'Otherwise, please ignore this email.' -@routes.get('/password-reset/(?P[^/]+)/?') -def start_password_reset(_ctx, params): - ''' Send a mail with secure token to the correlated user. ''' +@rest.routes.get('/password-reset/(?P[^/]+)/?') +def start_password_reset( + _ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user_name = params['user_name'] user = users.get_user_by_name_or_email(user_name) if not user.email: @@ -30,13 +30,13 @@ def start_password_reset(_ctx, params): return {} -@routes.post('/password-reset/(?P[^/]+)/?') -def finish_password_reset(ctx, params): - ''' Verify token from mail, generate a new password and return it. ''' +@rest.routes.post('/password-reset/(?P[^/]+)/?') +def finish_password_reset( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user_name = params['user_name'] user = users.get_user_by_name_or_email(user_name) good_token = auth.generate_authentication_token(user) - token = ctx.get_param_as_string('token', required=True) + token = ctx.get_param_as_string('token') if token != good_token: raise errors.ValidationError('Invalid password reset token.') new_password = users.reset_user_password(user) diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index cbf8f27e..6c76688f 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,44 +1,60 @@ -import datetime -from szurubooru import search, db, errors -from szurubooru.rest import routes +from typing import Optional, Dict +from datetime import datetime +from szurubooru import db, model, errors, rest, search from szurubooru.func import ( - auth, tags, posts, snapshots, favorites, scores, util, versions) + auth, tags, posts, snapshots, favorites, scores, serialization, versions) -_search_executor = search.Executor(search.configs.PostSearchConfig()) +_search_executor_config = search.configs.PostSearchConfig() +_search_executor = search.Executor(_search_executor_config) -def _serialize_post(ctx, post): +def _get_post_id(params: Dict[str, str]) -> int: + try: + return int(params['post_id']) + except TypeError: + raise posts.InvalidPostIdError( + 'Invalid post ID: %r.' % params['post_id']) + + +def _get_post(params: Dict[str, str]) -> model.Post: + return posts.get_post_by_id(_get_post_id(params)) + + +def _serialize_post( + ctx: rest.Context, post: Optional[model.Post]) -> rest.Response: return posts.serialize_post( post, ctx.user, - options=util.get_serialization_options(ctx)) + options=serialization.get_serialization_options(ctx)) -@routes.get('/posts/?') -def get_posts(ctx, _params=None): +@rest.routes.get('/posts/?') +def get_posts( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:list') - _search_executor.config.user = ctx.user + _search_executor_config.user = ctx.user return _search_executor.execute_and_serialize( ctx, lambda post: _serialize_post(ctx, post)) -@routes.post('/posts/?') -def create_post(ctx, _params=None): +@rest.routes.post('/posts/?') +def create_post( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: anonymous = ctx.get_param_as_bool('anonymous', default=False) if anonymous: auth.verify_privilege(ctx.user, 'posts:create:anonymous') else: auth.verify_privilege(ctx.user, 'posts:create:identified') - content = ctx.get_file('content', required=True) - tag_names = ctx.get_param_as_list('tags', required=False, default=[]) - safety = ctx.get_param_as_string('safety', required=True) - source = ctx.get_param_as_string('source', required=False, default=None) + content = ctx.get_file('content') + tag_names = ctx.get_param_as_list('tags', default=[]) + safety = ctx.get_param_as_string('safety') + source = ctx.get_param_as_string('source', default='') if ctx.has_param('contentUrl') and not source: - source = ctx.get_param_as_string('contentUrl') - relations = ctx.get_param_as_list('relations', required=False) or [] - notes = ctx.get_param_as_list('notes', required=False) or [] - flags = ctx.get_param_as_list('flags', required=False) or [] + source = ctx.get_param_as_string('contentUrl', default='') + relations = ctx.get_param_as_list('relations', default=[]) + notes = ctx.get_param_as_list('notes', default=[]) + flags = ctx.get_param_as_list('flags', default=[]) post, new_tags = posts.create_post( content, tag_names, None if anonymous else ctx.user) @@ -61,16 +77,16 @@ def create_post(ctx, _params=None): return _serialize_post(ctx, post) -@routes.get('/post/(?P[^/]+)/?') -def get_post(ctx, params): +@rest.routes.get('/post/(?P[^/]+)/?') +def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:view') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) return _serialize_post(ctx, post) -@routes.put('/post/(?P[^/]+)/?') -def update_post(ctx, params): - post = posts.get_post_by_id(params['post_id']) +@rest.routes.put('/post/(?P[^/]+)/?') +def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + post = _get_post(params) versions.verify_version(post, ctx) versions.bump_version(post) if ctx.has_file('content'): @@ -104,7 +120,7 @@ def update_post(ctx, params): if ctx.has_file('thumbnail'): auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) - post.last_edit_time = datetime.datetime.utcnow() + post.last_edit_time = datetime.utcnow() ctx.session.flush() snapshots.modify(post, ctx.user) ctx.session.commit() @@ -112,10 +128,10 @@ def update_post(ctx, params): return _serialize_post(ctx, post) -@routes.delete('/post/(?P[^/]+)/?') -def delete_post(ctx, params): +@rest.routes.delete('/post/(?P[^/]+)/?') +def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:delete') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) versions.verify_version(post, ctx) snapshots.delete(post, ctx.user) posts.delete(post) @@ -124,13 +140,14 @@ def delete_post(ctx, params): return {} -@routes.post('/post-merge/?') -def merge_posts(ctx, _params=None): - source_post_id = ctx.get_param_as_string('remove', required=True) or '' - target_post_id = ctx.get_param_as_string('mergeTo', required=True) or '' - replace_content = ctx.get_param_as_bool('replaceContent') +@rest.routes.post('/post-merge/?') +def merge_posts( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: + source_post_id = ctx.get_param_as_int('remove') + target_post_id = ctx.get_param_as_int('mergeTo') source_post = posts.get_post_by_id(source_post_id) target_post = posts.get_post_by_id(target_post_id) + replace_content = ctx.get_param_as_bool('replaceContent') versions.verify_version(source_post, ctx, 'removeVersion') versions.verify_version(target_post, ctx, 'mergeToVersion') versions.bump_version(target_post) @@ -141,16 +158,18 @@ def merge_posts(ctx, _params=None): return _serialize_post(ctx, target_post) -@routes.get('/featured-post/?') -def get_featured_post(ctx, _params=None): +@rest.routes.get('/featured-post/?') +def get_featured_post( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: post = posts.try_get_featured_post() return _serialize_post(ctx, post) -@routes.post('/featured-post/?') -def set_featured_post(ctx, _params=None): +@rest.routes.post('/featured-post/?') +def set_featured_post( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:feature') - post_id = ctx.get_param_as_int('id', required=True) + post_id = ctx.get_param_as_int('id') post = posts.get_post_by_id(post_id) featured_post = posts.try_get_featured_post() if featured_post and featured_post.post_id == post.post_id: @@ -162,55 +181,61 @@ def set_featured_post(ctx, _params=None): return _serialize_post(ctx, post) -@routes.put('/post/(?P[^/]+)/score/?') -def set_post_score(ctx, params): +@rest.routes.put('/post/(?P[^/]+)/score/?') +def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:score') - post = posts.get_post_by_id(params['post_id']) - score = ctx.get_param_as_int('score', required=True) + post = _get_post(params) + score = ctx.get_param_as_int('score') scores.set_score(post, ctx.user, score) ctx.session.commit() return _serialize_post(ctx, post) -@routes.delete('/post/(?P[^/]+)/score/?') -def delete_post_score(ctx, params): +@rest.routes.delete('/post/(?P[^/]+)/score/?') +def delete_post_score( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:score') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) scores.delete_score(post, ctx.user) ctx.session.commit() return _serialize_post(ctx, post) -@routes.post('/post/(?P[^/]+)/favorite/?') -def add_post_to_favorites(ctx, params): +@rest.routes.post('/post/(?P[^/]+)/favorite/?') +def add_post_to_favorites( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:favorite') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) favorites.set_favorite(post, ctx.user) ctx.session.commit() return _serialize_post(ctx, post) -@routes.delete('/post/(?P[^/]+)/favorite/?') -def delete_post_from_favorites(ctx, params): +@rest.routes.delete('/post/(?P[^/]+)/favorite/?') +def delete_post_from_favorites( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:favorite') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) favorites.unset_favorite(post, ctx.user) ctx.session.commit() return _serialize_post(ctx, post) -@routes.get('/post/(?P[^/]+)/around/?') -def get_posts_around(ctx, params): +@rest.routes.get('/post/(?P[^/]+)/around/?') +def get_posts_around( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:list') - _search_executor.config.user = ctx.user + _search_executor_config.user = ctx.user + post_id = _get_post_id(params) return _search_executor.get_around_and_serialize( - ctx, params['post_id'], lambda post: _serialize_post(ctx, post)) + ctx, post_id, lambda post: _serialize_post(ctx, post)) -@routes.post('/posts/reverse-search/?') -def get_posts_by_image(ctx, _params=None): +@rest.routes.post('/posts/reverse-search/?') +def get_posts_by_image( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:reverse_search') - content = ctx.get_file('content', required=True) + content = ctx.get_file('content') try: lookalikes = posts.search_by_image(content) diff --git a/server/szurubooru/api/snapshot_api.py b/server/szurubooru/api/snapshot_api.py index 009d8c97..cdcee74a 100644 --- a/server/szurubooru/api/snapshot_api.py +++ b/server/szurubooru/api/snapshot_api.py @@ -1,13 +1,14 @@ -from szurubooru import search -from szurubooru.rest import routes +from typing import Dict +from szurubooru import search, rest from szurubooru.func import auth, snapshots _search_executor = search.Executor(search.configs.SnapshotSearchConfig()) -@routes.get('/snapshots/?') -def get_snapshots(ctx, _params=None): +@rest.routes.get('/snapshots/?') +def get_snapshots( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'snapshots:list') return _search_executor.execute_and_serialize( ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user)) diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index a69673ad..7a379b3c 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -1,18 +1,22 @@ -import datetime -from szurubooru import db, search -from szurubooru.rest import routes -from szurubooru.func import auth, tags, snapshots, util, versions +from typing import Optional, List, Dict +from datetime import datetime +from szurubooru import db, model, search, rest +from szurubooru.func import auth, tags, snapshots, serialization, versions _search_executor = search.Executor(search.configs.TagSearchConfig()) -def _serialize(ctx, tag): +def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response: return tags.serialize_tag( - tag, options=util.get_serialization_options(ctx)) + tag, options=serialization.get_serialization_options(ctx)) -def _create_if_needed(tag_names, user): +def _get_tag(params: Dict[str, str]) -> model.Tag: + return tags.get_tag_by_name(params['tag_name']) + + +def _create_if_needed(tag_names: List[str], user: model.User) -> None: if not tag_names: return _existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) @@ -23,25 +27,22 @@ def _create_if_needed(tag_names, user): snapshots.create(tag, user) -@routes.get('/tags/?') -def get_tags(ctx, _params=None): +@rest.routes.get('/tags/?') +def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:list') return _search_executor.execute_and_serialize( ctx, lambda tag: _serialize(ctx, tag)) -@routes.post('/tags/?') -def create_tag(ctx, _params=None): +@rest.routes.post('/tags/?') +def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:create') - names = ctx.get_param_as_list('names', required=True) - category = ctx.get_param_as_string('category', required=True) - description = ctx.get_param_as_string( - 'description', required=False, default=None) - suggestions = ctx.get_param_as_list( - 'suggestions', required=False, default=[]) - implications = ctx.get_param_as_list( - 'implications', required=False, default=[]) + names = ctx.get_param_as_list('names') + category = ctx.get_param_as_string('category') + description = ctx.get_param_as_string('description', default='') + suggestions = ctx.get_param_as_list('suggestions', default=[]) + implications = ctx.get_param_as_list('implications', default=[]) _create_if_needed(suggestions, ctx.user) _create_if_needed(implications, ctx.user) @@ -56,16 +57,16 @@ def create_tag(ctx, _params=None): return _serialize(ctx, tag) -@routes.get('/tag/(?P.+)') -def get_tag(ctx, params): +@rest.routes.get('/tag/(?P.+)') +def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:view') - tag = tags.get_tag_by_name(params['tag_name']) + tag = _get_tag(params) return _serialize(ctx, tag) -@routes.put('/tag/(?P.+)') -def update_tag(ctx, params): - tag = tags.get_tag_by_name(params['tag_name']) +@rest.routes.put('/tag/(?P.+)') +def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + tag = _get_tag(params) versions.verify_version(tag, ctx) versions.bump_version(tag) if ctx.has_param('names'): @@ -78,7 +79,7 @@ def update_tag(ctx, params): if ctx.has_param('description'): auth.verify_privilege(ctx.user, 'tags:edit:description') tags.update_tag_description( - tag, ctx.get_param_as_string('description', default=None)) + tag, ctx.get_param_as_string('description')) if ctx.has_param('suggestions'): auth.verify_privilege(ctx.user, 'tags:edit:suggestions') suggestions = ctx.get_param_as_list('suggestions') @@ -89,7 +90,7 @@ def update_tag(ctx, params): implications = ctx.get_param_as_list('implications') _create_if_needed(implications, ctx.user) tags.update_tag_implications(tag, implications) - tag.last_edit_time = datetime.datetime.utcnow() + tag.last_edit_time = datetime.utcnow() ctx.session.flush() snapshots.modify(tag, ctx.user) ctx.session.commit() @@ -97,9 +98,9 @@ def update_tag(ctx, params): return _serialize(ctx, tag) -@routes.delete('/tag/(?P.+)') -def delete_tag(ctx, params): - tag = tags.get_tag_by_name(params['tag_name']) +@rest.routes.delete('/tag/(?P.+)') +def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + tag = _get_tag(params) versions.verify_version(tag, ctx) auth.verify_privilege(ctx.user, 'tags:delete') snapshots.delete(tag, ctx.user) @@ -109,10 +110,11 @@ def delete_tag(ctx, params): return {} -@routes.post('/tag-merge/?') -def merge_tags(ctx, _params=None): - source_tag_name = ctx.get_param_as_string('remove', required=True) or '' - target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' +@rest.routes.post('/tag-merge/?') +def merge_tags( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: + source_tag_name = ctx.get_param_as_string('remove') + target_tag_name = ctx.get_param_as_string('mergeTo') source_tag = tags.get_tag_by_name(source_tag_name) target_tag = tags.get_tag_by_name(target_tag_name) versions.verify_version(source_tag, ctx, 'removeVersion') @@ -126,10 +128,11 @@ def merge_tags(ctx, _params=None): return _serialize(ctx, target_tag) -@routes.get('/tag-siblings/(?P.+)') -def get_tag_siblings(ctx, params): +@rest.routes.get('/tag-siblings/(?P.+)') +def get_tag_siblings( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:view') - tag = tags.get_tag_by_name(params['tag_name']) + tag = _get_tag(params) result = tags.get_tag_siblings(tag) serialized_siblings = [] for sibling, occurrences in result: diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index 139da1d8..c7aaca89 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -1,15 +1,18 @@ -from szurubooru.rest import routes +from typing import Dict +from szurubooru import model, rest from szurubooru.func import ( - auth, tags, tag_categories, snapshots, util, versions) + auth, tags, tag_categories, snapshots, serialization, versions) -def _serialize(ctx, category): +def _serialize( + ctx: rest.Context, category: model.TagCategory) -> rest.Response: return tag_categories.serialize_category( - category, options=util.get_serialization_options(ctx)) + category, options=serialization.get_serialization_options(ctx)) -@routes.get('/tag-categories/?') -def get_tag_categories(ctx, _params=None): +@rest.routes.get('/tag-categories/?') +def get_tag_categories( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:list') categories = tag_categories.get_all_categories() return { @@ -17,11 +20,12 @@ def get_tag_categories(ctx, _params=None): } -@routes.post('/tag-categories/?') -def create_tag_category(ctx, _params=None): +@rest.routes.post('/tag-categories/?') +def create_tag_category( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:create') - name = ctx.get_param_as_string('name', required=True) - color = ctx.get_param_as_string('color', required=True) + name = ctx.get_param_as_string('name') + color = ctx.get_param_as_string('color') category = tag_categories.create_category(name, color) ctx.session.add(category) ctx.session.flush() @@ -31,15 +35,17 @@ def create_tag_category(ctx, _params=None): return _serialize(ctx, category) -@routes.get('/tag-category/(?P[^/]+)/?') -def get_tag_category(ctx, params): +@rest.routes.get('/tag-category/(?P[^/]+)/?') +def get_tag_category( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:view') category = tag_categories.get_category_by_name(params['category_name']) return _serialize(ctx, category) -@routes.put('/tag-category/(?P[^/]+)/?') -def update_tag_category(ctx, params): +@rest.routes.put('/tag-category/(?P[^/]+)/?') +def update_tag_category( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: category = tag_categories.get_category_by_name( params['category_name'], lock=True) versions.verify_version(category, ctx) @@ -59,8 +65,9 @@ def update_tag_category(ctx, params): return _serialize(ctx, category) -@routes.delete('/tag-category/(?P[^/]+)/?') -def delete_tag_category(ctx, params): +@rest.routes.delete('/tag-category/(?P[^/]+)/?') +def delete_tag_category( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: category = tag_categories.get_category_by_name( params['category_name'], lock=True) versions.verify_version(category, ctx) @@ -72,8 +79,9 @@ def delete_tag_category(ctx, params): return {} -@routes.put('/tag-category/(?P[^/]+)/default/?') -def set_tag_category_as_default(ctx, params): +@rest.routes.put('/tag-category/(?P[^/]+)/default/?') +def set_tag_category_as_default( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:set_default') category = tag_categories.get_category_by_name( params['category_name'], lock=True) diff --git a/server/szurubooru/api/upload_api.py b/server/szurubooru/api/upload_api.py index eaf2880b..9200eaa0 100644 --- a/server/szurubooru/api/upload_api.py +++ b/server/szurubooru/api/upload_api.py @@ -1,10 +1,12 @@ -from szurubooru.rest import routes +from typing import Dict +from szurubooru import rest from szurubooru.func import auth, file_uploads -@routes.post('/uploads/?') -def create_temporary_file(ctx, _params=None): +@rest.routes.post('/uploads/?') +def create_temporary_file( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'uploads:create') - content = ctx.get_file('content', required=True, allow_tokens=False) + content = ctx.get_file('content', allow_tokens=False) token = file_uploads.save(content) return {'token': token} diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index 187e4686..910f2a42 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -1,56 +1,57 @@ -from szurubooru import search -from szurubooru.rest import routes -from szurubooru.func import auth, users, util, versions +from typing import Any, Dict +from szurubooru import model, search, rest +from szurubooru.func import auth, users, serialization, versions _search_executor = search.Executor(search.configs.UserSearchConfig()) -def _serialize(ctx, user, **kwargs): +def _serialize( + ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response: return users.serialize_user( user, ctx.user, - options=util.get_serialization_options(ctx), + options=serialization.get_serialization_options(ctx), **kwargs) -@routes.get('/users/?') -def get_users(ctx, _params=None): +@rest.routes.get('/users/?') +def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'users:list') return _search_executor.execute_and_serialize( ctx, lambda user: _serialize(ctx, user)) -@routes.post('/users/?') -def create_user(ctx, _params=None): +@rest.routes.post('/users/?') +def create_user( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'users:create') - name = ctx.get_param_as_string('name', required=True) - password = ctx.get_param_as_string('password', required=True) - email = ctx.get_param_as_string('email', required=False, default='') + name = ctx.get_param_as_string('name') + password = ctx.get_param_as_string('password') + email = ctx.get_param_as_string('email', default='') user = users.create_user(name, password, email) if ctx.has_param('rank'): - users.update_user_rank( - user, ctx.get_param_as_string('rank'), ctx.user) + users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user) if ctx.has_param('avatarStyle'): users.update_user_avatar( user, ctx.get_param_as_string('avatarStyle'), - ctx.get_file('avatar')) + ctx.get_file('avatar', default=b'')) ctx.session.add(user) ctx.session.commit() return _serialize(ctx, user, force_show_email=True) -@routes.get('/user/(?P[^/]+)/?') -def get_user(ctx, params): +@rest.routes.get('/user/(?P[^/]+)/?') +def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user = users.get_user_by_name(params['user_name']) if ctx.user.user_id != user.user_id: auth.verify_privilege(ctx.user, 'users:view') return _serialize(ctx, user) -@routes.put('/user/(?P[^/]+)/?') -def update_user(ctx, params): +@rest.routes.put('/user/(?P[^/]+)/?') +def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user = users.get_user_by_name(params['user_name']) versions.verify_version(user, ctx) versions.bump_version(user) @@ -74,13 +75,13 @@ def update_user(ctx, params): users.update_user_avatar( user, ctx.get_param_as_string('avatarStyle'), - ctx.get_file('avatar')) + ctx.get_file('avatar', default=b'')) ctx.session.commit() return _serialize(ctx, user) -@routes.delete('/user/(?P[^/]+)/?') -def delete_user(ctx, params): +@rest.routes.delete('/user/(?P[^/]+)/?') +def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user = users.get_user_by_name(params['user_name']) versions.verify_version(user, ctx) infix = 'self' if ctx.user.user_id == user.user_id else 'any' diff --git a/server/szurubooru/config.py b/server/szurubooru/config.py index e5693117..567b4574 100644 --- a/server/szurubooru/config.py +++ b/server/szurubooru/config.py @@ -1,8 +1,9 @@ +from typing import Dict import os import yaml -def merge(left, right): +def merge(left: Dict, right: Dict) -> Dict: for key in right: if key in left: if isinstance(left[key], dict) and isinstance(right[key], dict): @@ -14,7 +15,7 @@ def merge(left, right): return left -def read_config(): +def read_config() -> Dict: with open('../config.yaml.dist') as handle: ret = yaml.load(handle.read()) if os.path.exists('../config.yaml'): diff --git a/server/szurubooru/db.py b/server/szurubooru/db.py new file mode 100644 index 00000000..f90bfaf9 --- /dev/null +++ b/server/szurubooru/db.py @@ -0,0 +1,36 @@ +from typing import Any +import threading +import sqlalchemy as sa +import sqlalchemy.orm +from szurubooru import config + +# pylint: disable=invalid-name +_data = threading.local() +_engine = sa.create_engine(config.config['database']) # type: Any +sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any +session = sa.orm.scoped_session(sessionmaker) # type: Any + + +def get_session() -> Any: + global session + return session + + +def set_sesssion(new_session: Any) -> None: + global session + session = new_session + + +def reset_query_count() -> None: + _data.query_count = 0 + + +def get_query_count() -> int: + return _data.query_count + + +def _bump_query_count() -> None: + _data.query_count = getattr(_data, 'query_count', 0) + 1 + + +sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count()) diff --git a/server/szurubooru/db/__init__.py b/server/szurubooru/db/__init__.py deleted file mode 100644 index 3eb18833..00000000 --- a/server/szurubooru/db/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from szurubooru.db.base import Base -from szurubooru.db.user import User -from szurubooru.db.tag_category import TagCategory -from szurubooru.db.tag import (Tag, TagName, TagSuggestion, TagImplication) -from szurubooru.db.post import ( - Post, - PostTag, - PostRelation, - PostFavorite, - PostScore, - PostNote, - PostFeature) -from szurubooru.db.comment import (Comment, CommentScore) -from szurubooru.db.snapshot import Snapshot -from szurubooru.db.session import ( - session, sessionmaker, reset_query_count, get_query_count) -import szurubooru.db.util diff --git a/server/szurubooru/db/session.py b/server/szurubooru/db/session.py deleted file mode 100644 index fd77b4c2..00000000 --- a/server/szurubooru/db/session.py +++ /dev/null @@ -1,27 +0,0 @@ -import threading -import sqlalchemy -from szurubooru import config - - -# pylint: disable=invalid-name -_engine = sqlalchemy.create_engine(config.config['database']) -sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine, autoflush=False) -session = sqlalchemy.orm.scoped_session(sessionmaker) - -_data = threading.local() - - -def reset_query_count(): - _data.query_count = 0 - - -def get_query_count(): - return _data.query_count - - -def _bump_query_count(): - _data.query_count = getattr(_data, 'query_count', 0) + 1 - - -sqlalchemy.event.listen( - _engine, 'after_execute', lambda *args: _bump_query_count()) diff --git a/server/szurubooru/db/util.py b/server/szurubooru/db/util.py deleted file mode 100644 index d6edf188..00000000 --- a/server/szurubooru/db/util.py +++ /dev/null @@ -1,34 +0,0 @@ -from sqlalchemy.inspection import inspect - - -def get_resource_info(entity): - serializers = { - 'tag': lambda tag: tag.first_name, - 'tag_category': lambda category: category.name, - 'comment': lambda comment: comment.comment_id, - 'post': lambda post: post.post_id, - } - - resource_type = entity.__table__.name - assert resource_type in serializers - - primary_key = inspect(entity).identity - assert primary_key is not None - assert len(primary_key) == 1 - - resource_name = serializers[resource_type](entity) - assert resource_name - - resource_pkey = primary_key[0] - assert resource_pkey - - return (resource_type, resource_pkey, resource_name) - - -def get_aux_entity(session, get_table_info, entity, user): - table, get_column = get_table_info(entity) - return session \ - .query(table) \ - .filter(get_column(table) == get_column(entity)) \ - .filter(table.user_id == user.user_id) \ - .one_or_none() diff --git a/server/szurubooru/errors.py b/server/szurubooru/errors.py index 4fbb67b6..b5f1cc3b 100644 --- a/server/szurubooru/errors.py +++ b/server/szurubooru/errors.py @@ -1,5 +1,11 @@ +from typing import Dict + + class BaseError(RuntimeError): - def __init__(self, message='Unknown error', extra_fields=None): + def __init__( + self, + message: str='Unknown error', + extra_fields: Dict[str, str]=None) -> None: super().__init__(message) self.extra_fields = extra_fields diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 48957a1f..f39fcf92 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -2,7 +2,10 @@ import os import time import logging import threading +from typing import Callable, Any, Type + import coloredlogs +import sqlalchemy as sa import sqlalchemy.orm.exc from szurubooru import config, db, errors, rest from szurubooru.func import posts, file_uploads @@ -10,7 +13,10 @@ from szurubooru.func import posts, file_uploads from szurubooru import api, middleware -def _map_error(ex, target_class, title): +def _map_error( + ex: Exception, + target_class: Type[rest.errors.BaseHttpError], + title: str) -> rest.errors.BaseHttpError: return target_class( name=type(ex).__name__, title=title, @@ -18,38 +24,38 @@ def _map_error(ex, target_class, title): extra_fields=getattr(ex, 'extra_fields', {})) -def _on_auth_error(ex): +def _on_auth_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error') -def _on_validation_error(ex): +def _on_validation_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error') -def _on_search_error(ex): +def _on_search_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error') -def _on_integrity_error(ex): +def _on_integrity_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation') -def _on_not_found_error(ex): +def _on_not_found_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpNotFound, 'Not found') -def _on_processing_error(ex): +def _on_processing_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error') -def _on_third_party_error(ex): +def _on_third_party_error(ex: Exception) -> None: raise _map_error( ex, rest.errors.HttpInternalServerError, 'Server configuration error') -def _on_stale_data_error(_ex): +def _on_stale_data_error(_ex: Exception) -> None: raise rest.errors.HttpConflict( name='IntegrityError', title='Integrity violation', @@ -58,7 +64,7 @@ def _on_stale_data_error(_ex): 'Please try again.')) -def validate_config(): +def validate_config() -> None: ''' Check whether config doesn't contain errors that might prove lethal at runtime. @@ -86,7 +92,7 @@ def validate_config(): raise errors.ConfigError('Database is not configured') -def purge_old_uploads(): +def purge_old_uploads() -> None: while True: try: file_uploads.purge_old_uploads() @@ -95,7 +101,7 @@ def purge_old_uploads(): time.sleep(60 * 5) -def create_app(): +def create_app() -> Callable[[Any, Any], Any]: ''' Create a WSGI compatible App object. ''' validate_config() coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') @@ -122,7 +128,7 @@ def create_app(): rest.errors.handle(errors.NotFoundError, _on_not_found_error) rest.errors.handle(errors.ProcessingError, _on_processing_error) rest.errors.handle(errors.ThirdPartyError, _on_third_party_error) - rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) + rest.errors.handle(sa.orm.exc.StaleDataError, _on_stale_data_error) return rest.application diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index d71c8f9d..25c991c4 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -1,22 +1,22 @@ import hashlib import random from collections import OrderedDict -from szurubooru import config, db, errors +from szurubooru import config, model, errors from szurubooru.func import util RANK_MAP = OrderedDict([ - (db.User.RANK_ANONYMOUS, 'anonymous'), - (db.User.RANK_RESTRICTED, 'restricted'), - (db.User.RANK_REGULAR, 'regular'), - (db.User.RANK_POWER, 'power'), - (db.User.RANK_MODERATOR, 'moderator'), - (db.User.RANK_ADMINISTRATOR, 'administrator'), - (db.User.RANK_NOBODY, 'nobody'), + (model.User.RANK_ANONYMOUS, 'anonymous'), + (model.User.RANK_RESTRICTED, 'restricted'), + (model.User.RANK_REGULAR, 'regular'), + (model.User.RANK_POWER, 'power'), + (model.User.RANK_MODERATOR, 'moderator'), + (model.User.RANK_ADMINISTRATOR, 'administrator'), + (model.User.RANK_NOBODY, 'nobody'), ]) -def get_password_hash(salt, password): +def get_password_hash(salt: str, password: str) -> str: ''' Retrieve new-style password hash. ''' digest = hashlib.sha256() digest.update(config.config['secret'].encode('utf8')) @@ -25,7 +25,7 @@ def get_password_hash(salt, password): return digest.hexdigest() -def get_legacy_password_hash(salt, password): +def get_legacy_password_hash(salt: str, password: str) -> str: ''' Retrieve old-style password hash. ''' digest = hashlib.sha1() digest.update(b'1A2/$_4xVa') @@ -34,7 +34,7 @@ def get_legacy_password_hash(salt, password): return digest.hexdigest() -def create_password(): +def create_password() -> str: alphabet = { 'c': list('bcdfghijklmnpqrstvwxyz'), 'v': list('aeiou'), @@ -44,7 +44,7 @@ def create_password(): return ''.join(random.choice(alphabet[l]) for l in list(pattern)) -def is_valid_password(user, password): +def is_valid_password(user: model.User, password: str) -> bool: assert user salt, valid_hash = user.password_salt, user.password_hash possible_hashes = [ @@ -54,7 +54,7 @@ def is_valid_password(user, password): return valid_hash in possible_hashes -def has_privilege(user, privilege_name): +def has_privilege(user: model.User, privilege_name: str) -> bool: assert user all_ranks = list(RANK_MAP.keys()) assert privilege_name in config.config['privileges'] @@ -65,13 +65,13 @@ def has_privilege(user, privilege_name): return user.rank in good_ranks -def verify_privilege(user, privilege_name): +def verify_privilege(user: model.User, privilege_name: str) -> None: assert user if not has_privilege(user, privilege_name): raise errors.AuthError('Insufficient privileges to do this.') -def generate_authentication_token(user): +def generate_authentication_token(user: model.User) -> str: ''' Generate nonguessable challenge (e.g. links in password reminder). ''' assert user digest = hashlib.md5() diff --git a/server/szurubooru/func/cache.py b/server/szurubooru/func/cache.py index 4b775548..345835c2 100644 --- a/server/szurubooru/func/cache.py +++ b/server/szurubooru/func/cache.py @@ -1,21 +1,21 @@ +from typing import Any, List, Dict from datetime import datetime class LruCacheItem: - def __init__(self, key, value): + def __init__(self, key: object, value: Any) -> None: self.key = key self.value = value self.timestamp = datetime.utcnow() class LruCache: - def __init__(self, length, delta=None): + def __init__(self, length: int) -> None: self.length = length - self.delta = delta - self.hash = {} - self.item_list = [] + self.hash = {} # type: Dict[object, LruCacheItem] + self.item_list = [] # type: List[LruCacheItem] - def insert_item(self, item): + def insert_item(self, item: LruCacheItem) -> None: if item.key in self.hash: item_index = next( i @@ -31,11 +31,11 @@ class LruCache: self.hash[item.key] = item self.item_list.insert(0, item) - def remove_all(self): + def remove_all(self) -> None: self.hash = {} self.item_list = [] - def remove_item(self, item): + def remove_item(self, item: LruCacheItem) -> None: del self.hash[item.key] del self.item_list[self.item_list.index(item)] @@ -43,22 +43,22 @@ class LruCache: _CACHE = LruCache(length=100) -def purge(): +def purge() -> None: _CACHE.remove_all() -def has(key): +def has(key: object) -> bool: return key in _CACHE.hash -def get(key): +def get(key: object) -> Any: return _CACHE.hash[key].value -def remove(key): +def remove(key: object) -> None: if has(key): del _CACHE.hash[key] -def put(key, value): +def put(key: object, value: Any) -> None: _CACHE.insert_item(LruCacheItem(key, value)) diff --git a/server/szurubooru/func/comments.py b/server/szurubooru/func/comments.py index 6b7def85..fe15d8b6 100644 --- a/server/szurubooru/func/comments.py +++ b/server/szurubooru/func/comments.py @@ -1,6 +1,7 @@ -import datetime -from szurubooru import db, errors -from szurubooru.func import users, scores, util +from datetime import datetime +from typing import Any, Optional, List, Dict, Callable +from szurubooru import db, model, errors, rest +from szurubooru.func import users, scores, util, serialization class InvalidCommentIdError(errors.ValidationError): @@ -15,52 +16,87 @@ class EmptyCommentTextError(errors.ValidationError): pass -def serialize_comment(comment, auth_user, options=None): - return util.serialize_entity( - comment, - { - 'id': lambda: comment.comment_id, - 'user': - lambda: users.serialize_micro_user(comment.user, auth_user), - 'postId': lambda: comment.post.post_id, - 'version': lambda: comment.version, - 'text': lambda: comment.text, - 'creationTime': lambda: comment.creation_time, - 'lastEditTime': lambda: comment.last_edit_time, - 'score': lambda: comment.score, - 'ownScore': lambda: scores.get_score(comment, auth_user), - }, - options) +class CommentSerializer(serialization.BaseSerializer): + def __init__(self, comment: model.Comment, auth_user: model.User) -> None: + self.comment = comment + self.auth_user = auth_user + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'id': self.serialize_id, + 'user': self.serialize_user, + 'postId': self.serialize_post_id, + 'version': self.serialize_version, + 'text': self.serialize_text, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'score': self.serialize_score, + 'ownScore': self.serialize_own_score, + } + + def serialize_id(self) -> Any: + return self.comment.comment_id + + def serialize_user(self) -> Any: + return users.serialize_micro_user(self.comment.user, self.auth_user) + + def serialize_post_id(self) -> Any: + return self.comment.post.post_id + + def serialize_version(self) -> Any: + return self.comment.version + + def serialize_text(self) -> Any: + return self.comment.text + + def serialize_creation_time(self) -> Any: + return self.comment.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.comment.last_edit_time + + def serialize_score(self) -> Any: + return self.comment.score + + def serialize_own_score(self) -> Any: + return scores.get_score(self.comment, self.auth_user) -def try_get_comment_by_id(comment_id): - try: - comment_id = int(comment_id) - except ValueError: - raise InvalidCommentIdError('Invalid comment ID: %r.' % comment_id) +def serialize_comment( + comment: model.Comment, + auth_user: model.User, + options: List[str]=[]) -> rest.Response: + if comment is None: + return None + return CommentSerializer(comment, auth_user).serialize(options) + + +def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]: + comment_id = int(comment_id) return db.session \ - .query(db.Comment) \ - .filter(db.Comment.comment_id == comment_id) \ + .query(model.Comment) \ + .filter(model.Comment.comment_id == comment_id) \ .one_or_none() -def get_comment_by_id(comment_id): +def get_comment_by_id(comment_id: int) -> model.Comment: comment = try_get_comment_by_id(comment_id) if comment: return comment raise CommentNotFoundError('Comment %r not found.' % comment_id) -def create_comment(user, post, text): - comment = db.Comment() +def create_comment( + user: model.User, post: model.Post, text: str) -> model.Comment: + comment = model.Comment() comment.user = user comment.post = post update_comment_text(comment, text) - comment.creation_time = datetime.datetime.utcnow() + comment.creation_time = datetime.utcnow() return comment -def update_comment_text(comment, text): +def update_comment_text(comment: model.Comment, text: str) -> None: assert comment if not text: raise EmptyCommentTextError('Comment text cannot be empty.') diff --git a/server/szurubooru/func/diff.py b/server/szurubooru/func/diff.py index 0950f0f0..90014f7e 100644 --- a/server/szurubooru/func/diff.py +++ b/server/szurubooru/func/diff.py @@ -1,21 +1,26 @@ -def get_list_diff(old, new): - value = {'type': 'list change', 'added': [], 'removed': []} +from typing import List, Dict, Any + + +def get_list_diff(old: List[Any], new: List[Any]) -> Any: equal = True + removed = [] # type: List[Any] + added = [] # type: List[Any] for item in old: if item not in new: equal = False - value['removed'].append(item) + removed.append(item) for item in new: if item not in old: equal = False - value['added'].append(item) + added.append(item) - return None if equal else value + return None if equal else { + 'type': 'list change', 'added': added, 'removed': removed} -def get_dict_diff(old, new): +def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any: value = {} equal = True diff --git a/server/szurubooru/func/favorites.py b/server/szurubooru/func/favorites.py index 00952de7..f567bfad 100644 --- a/server/szurubooru/func/favorites.py +++ b/server/szurubooru/func/favorites.py @@ -1,32 +1,34 @@ -import datetime -from szurubooru import db, errors +from typing import Any, Optional, Callable, Tuple +from datetime import datetime +from szurubooru import db, model, errors class InvalidFavoriteTargetError(errors.ValidationError): pass -def _get_table_info(entity): +def _get_table_info( + entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]: assert entity - resource_type, _, _ = db.util.get_resource_info(entity) + resource_type, _, _ = model.util.get_resource_info(entity) if resource_type == 'post': - return db.PostFavorite, lambda table: table.post_id + return model.PostFavorite, lambda table: table.post_id raise InvalidFavoriteTargetError() -def _get_fav_entity(entity, user): +def _get_fav_entity(entity: model.Base, user: model.User) -> model.Base: assert entity assert user - return db.util.get_aux_entity(db.session, _get_table_info, entity, user) + return model.util.get_aux_entity(db.session, _get_table_info, entity, user) -def has_favorited(entity, user): +def has_favorited(entity: model.Base, user: model.User) -> bool: assert entity assert user return _get_fav_entity(entity, user) is not None -def unset_favorite(entity, user): +def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None: assert entity assert user fav_entity = _get_fav_entity(entity, user) @@ -34,7 +36,7 @@ def unset_favorite(entity, user): db.session.delete(fav_entity) -def set_favorite(entity, user): +def set_favorite(entity: model.Base, user: Optional[model.User]) -> None: from szurubooru.func import scores assert entity assert user @@ -48,5 +50,5 @@ def set_favorite(entity, user): fav_entity = table() setattr(fav_entity, get_column(table).name, get_column(entity)) fav_entity.user = user - fav_entity.time = datetime.datetime.utcnow() + fav_entity.time = datetime.utcnow() db.session.add(fav_entity) diff --git a/server/szurubooru/func/file_uploads.py b/server/szurubooru/func/file_uploads.py index 95698e36..e7f93d83 100644 --- a/server/szurubooru/func/file_uploads.py +++ b/server/szurubooru/func/file_uploads.py @@ -1,27 +1,28 @@ -import datetime +from typing import Optional +from datetime import datetime, timedelta from szurubooru.func import files, util MAX_MINUTES = 60 -def _get_path(checksum): +def _get_path(checksum: str) -> str: return 'temporary-uploads/%s.dat' % checksum -def purge_old_uploads(): - now = datetime.datetime.now() +def purge_old_uploads() -> None: + now = datetime.now() for file in files.scan('temporary-uploads'): - file_time = datetime.datetime.fromtimestamp(file.stat().st_ctime) - if now - file_time > datetime.timedelta(minutes=MAX_MINUTES): + file_time = datetime.fromtimestamp(file.stat().st_ctime) + if now - file_time > timedelta(minutes=MAX_MINUTES): files.delete('temporary-uploads/%s' % file.name) -def get(checksum): +def get(checksum: str) -> Optional[bytes]: return files.get('temporary-uploads/%s.dat' % checksum) -def save(content): +def save(content: bytes) -> str: checksum = util.get_sha1(content) path = _get_path(checksum) if not files.has(path): diff --git a/server/szurubooru/func/files.py b/server/szurubooru/func/files.py index 3ca87776..0a992ee4 100644 --- a/server/szurubooru/func/files.py +++ b/server/szurubooru/func/files.py @@ -1,32 +1,33 @@ +from typing import Any, Optional, List import os from szurubooru import config -def _get_full_path(path): +def _get_full_path(path: str) -> str: return os.path.join(config.config['data_dir'], path) -def delete(path): +def delete(path: str) -> None: full_path = _get_full_path(path) if os.path.exists(full_path): os.unlink(full_path) -def has(path): +def has(path: str) -> bool: return os.path.exists(_get_full_path(path)) -def scan(path): +def scan(path: str) -> List[os.DirEntry]: if has(path): - return os.scandir(_get_full_path(path)) + return list(os.scandir(_get_full_path(path))) return [] -def move(source_path, target_path): - return os.rename(_get_full_path(source_path), _get_full_path(target_path)) +def move(source_path: str, target_path: str) -> None: + os.rename(_get_full_path(source_path), _get_full_path(target_path)) -def get(path): +def get(path: str) -> Optional[bytes]: full_path = _get_full_path(path) if not os.path.exists(full_path): return None @@ -34,7 +35,7 @@ def get(path): return handle.read() -def save(path, content): +def save(path: str, content: bytes) -> None: full_path = _get_full_path(path) os.makedirs(os.path.dirname(full_path), exist_ok=True) with open(full_path, 'wb') as handle: diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index c89a2ec1..dc998e83 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -1,6 +1,7 @@ import logging from io import BytesIO from datetime import datetime +from typing import Any, Optional, Tuple, Set, List, Callable import elasticsearch import elasticsearch_dsl import numpy as np @@ -10,13 +11,8 @@ from szurubooru import config, errors # pylint: disable=invalid-name logger = logging.getLogger(__name__) -es = elasticsearch.Elasticsearch([{ - 'host': config.config['elasticsearch']['host'], - 'port': config.config['elasticsearch']['port'], -}]) - -# Math based on paper from H. Chi Wong, Marshall Bern and David Goldber +# Math based on paper from H. Chi Wong, Marshall Bern and David Goldberg # Math code taken from https://github.com/ascribe/image-match # (which is licensed under Apache 2 license) @@ -32,14 +28,27 @@ MAX_WORDS = 63 ES_DOC_TYPE = 'image' ES_MAX_RESULTS = 100 +Window = Tuple[Tuple[float, float], Tuple[float, float]] +NpMatrix = Any -def _preprocess_image(image_or_path): - img = Image.open(BytesIO(image_or_path)) + +def _get_session() -> elasticsearch.Elasticsearch: + return elasticsearch.Elasticsearch([{ + 'host': config.config['elasticsearch']['host'], + 'port': config.config['elasticsearch']['port'], + }]) + + +def _preprocess_image(content: bytes) -> NpMatrix: + img = Image.open(BytesIO(content)) img = img.convert('RGB') return rgb2gray(np.asarray(img, dtype=np.uint8)) -def _crop_image(image, lower_percentile, upper_percentile): +def _crop_image( + image: NpMatrix, + lower_percentile: float, + upper_percentile: float) -> Window: rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1)) cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0)) upper_column_limit = np.searchsorted( @@ -56,16 +65,19 @@ def _crop_image(image, lower_percentile, upper_percentile): if lower_column_limit > upper_column_limit: lower_column_limit = int(lower_percentile / 100. * image.shape[1]) upper_column_limit = int(upper_percentile / 100. * image.shape[1]) - return [ + return ( (lower_row_limit, upper_row_limit), - (lower_column_limit, upper_column_limit)] + (lower_column_limit, upper_column_limit)) -def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): +def _normalize_and_threshold( + diff_array: NpMatrix, + identical_tolerance: float, + n_levels: int) -> None: mask = np.abs(diff_array) < identical_tolerance diff_array[mask] = 0. if np.all(mask): - return None + return positive_cutoffs = np.percentile( diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1)) negative_cutoffs = np.percentile( @@ -82,18 +94,24 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): diff_array[ (diff_array <= interval[0]) & (diff_array >= interval[1])] = \ -(level + 1) - return None -def _compute_grid_points(image, n, window=None): +def _compute_grid_points( + image: NpMatrix, + n: float, + window: Window=None) -> Tuple[NpMatrix, NpMatrix]: if window is None: - window = [(0, image.shape[0]), (0, image.shape[1])] + window = ((0, image.shape[0]), (0, image.shape[1])) x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1] y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1] return x_coords, y_coords -def _compute_mean_level(image, x_coords, y_coords, p): +def _compute_mean_level( + image: NpMatrix, + x_coords: NpMatrix, + y_coords: NpMatrix, + p: Optional[float]) -> NpMatrix: if p is None: p = max([2.0, int(0.5 + min(image.shape) / 20.)]) avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0])) @@ -108,7 +126,7 @@ def _compute_mean_level(image, x_coords, y_coords, p): return avg_grey -def _compute_differentials(grey_level_matrix): +def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: flipped = np.fliplr(grey_level_matrix) right_neighbors = -np.concatenate( ( @@ -152,8 +170,8 @@ def _compute_differentials(grey_level_matrix): lower_right_neighbors])) -def _generate_signature(path_or_image): - im_array = _preprocess_image(path_or_image) +def _generate_signature(content: bytes) -> NpMatrix: + im_array = _preprocess_image(content) image_limits = _crop_image( im_array, lower_percentile=LOWER_PERCENTILE, @@ -169,7 +187,7 @@ def _generate_signature(path_or_image): return np.ravel(diff_matrix).astype('int8') -def _get_words(array, k, n): +def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: word_positions = np.linspace( 0, array.shape[0], n, endpoint=False).astype('int') assert k <= array.shape[0] @@ -187,21 +205,23 @@ def _get_words(array, k, n): return words -def _words_to_int(word_array): +def _words_to_int(word_array: NpMatrix) -> NpMatrix: width = word_array.shape[1] coding_vector = 3**np.arange(width) return np.dot(word_array + 1, coding_vector) -def _max_contrast(array): +def _max_contrast(array: NpMatrix) -> None: array[array > 0] = 1 array[array < 0] = -1 - return None -def _normalized_distance(_target_array, _vec, nan_value=1.0): - target_array = _target_array.astype(int) - vec = _vec.astype(int) +def _normalized_distance( + target_array: NpMatrix, + vec: NpMatrix, + nan_value: float=1.0) -> List[float]: + target_array = target_array.astype(int) + vec = vec.astype(int) topvec = np.linalg.norm(vec - target_array, axis=1) norm1 = np.linalg.norm(vec, axis=0) norm2 = np.linalg.norm(target_array, axis=1) @@ -210,9 +230,9 @@ def _normalized_distance(_target_array, _vec, nan_value=1.0): return finvec -def _safety_blanket(default_param_factory): - def wrapper_outer(target_function): - def wrapper_inner(*args, **kwargs): +def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable: + def wrapper_outer(target_function: Callable) -> Callable: + def wrapper_inner(*args: Any, **kwargs: Any) -> Any: try: return target_function(*args, **kwargs) except elasticsearch.exceptions.NotFoundError: @@ -226,20 +246,20 @@ def _safety_blanket(default_param_factory): except IOError: raise errors.ProcessingError('Not an image.') except Exception as ex: - raise errors.ThirdPartyError('Unknown error (%s).', ex) + raise errors.ThirdPartyError('Unknown error (%s).' % ex) return wrapper_inner return wrapper_outer class Lookalike: - def __init__(self, score, distance, path): + def __init__(self, score: int, distance: float, path: Any) -> None: self.score = score self.distance = distance self.path = path @_safety_blanket(lambda: None) -def add_image(path, image_content): +def add_image(path: str, image_content: bytes) -> None: assert path assert image_content signature = _generate_signature(image_content) @@ -253,7 +273,7 @@ def add_image(path, image_content): for i in range(MAX_WORDS): record['simple_word_' + str(i)] = words[i].tolist() - es.index( + _get_session().index( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body=record, @@ -261,20 +281,20 @@ def add_image(path, image_content): @_safety_blanket(lambda: None) -def delete_image(path): +def delete_image(path: str) -> None: assert path - es.delete_by_query( + _get_session().delete_by_query( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body={'query': {'term': {'path': path}}}) @_safety_blanket(lambda: []) -def search_by_image(image_content): +def search_by_image(image_content: bytes) -> List[Lookalike]: signature = _generate_signature(image_content) words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) - res = es.search( + res = _get_session().search( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body={ @@ -299,7 +319,7 @@ def search_by_image(image_content): sigs = np.array([x['_source']['signature'] for x in res]) dists = _normalized_distance(sigs, np.array(signature)) - ids = set() + ids = set() # type: Set[int] ret = [] for item, dist in zip(res, dists): id = item['_id'] @@ -314,8 +334,8 @@ def search_by_image(image_content): @_safety_blanket(lambda: None) -def purge(): - es.delete_by_query( +def purge() -> None: + _get_session().delete_by_query( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body={'query': {'match_all': {}}}, @@ -323,10 +343,10 @@ def purge(): @_safety_blanket(lambda: set()) -def get_all_paths(): +def get_all_paths() -> Set[str]: search = ( elasticsearch_dsl.Search( - using=es, + using=_get_session(), index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE) .source(['path'])) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index fdab793b..103a6ff8 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -1,3 +1,4 @@ +from typing import List import logging import json import shlex @@ -15,23 +16,23 @@ _SCALE_FIT_FMT = \ class Image: - def __init__(self, content): + def __init__(self, content: bytes) -> None: self.content = content self._reload_info() @property - def width(self): + def width(self) -> int: return self.info['streams'][0]['width'] @property - def height(self): + def height(self) -> int: return self.info['streams'][0]['height'] @property - def frames(self): + def frames(self) -> int: return self.info['streams'][0]['nb_read_frames'] - def resize_fill(self, width, height): + def resize_fill(self, width: int, height: int) -> None: cli = [ '-i', '{path}', '-f', 'image2', @@ -53,7 +54,7 @@ class Image: assert self.content self._reload_info() - def to_png(self): + def to_png(self) -> bytes: return self._execute([ '-i', '{path}', '-f', 'image2', @@ -63,7 +64,7 @@ class Image: '-', ]) - def to_jpeg(self): + def to_jpeg(self) -> bytes: return self._execute([ '-f', 'lavfi', '-i', 'color=white:s=%dx%d' % (self.width, self.height), @@ -76,7 +77,7 @@ class Image: '-', ]) - def _execute(self, cli, program='ffmpeg'): + def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes: extension = mime.get_extension(mime.get_mime_type(self.content)) assert extension with util.create_temp_file(suffix='.' + extension) as handle: @@ -99,7 +100,7 @@ class Image: 'Error while processing image.\n' + err.decode('utf-8')) return out - def _reload_info(self): + def _reload_info(self) -> None: self.info = json.loads(self._execute([ '-i', '{path}', '-of', 'json', diff --git a/server/szurubooru/func/mailer.py b/server/szurubooru/func/mailer.py index 94f9c506..76682f11 100644 --- a/server/szurubooru/func/mailer.py +++ b/server/szurubooru/func/mailer.py @@ -3,7 +3,7 @@ import email.mime.text from szurubooru import config -def send_mail(sender, recipient, subject, body): +def send_mail(sender: str, recipient: str, subject: str, body: str) -> None: msg = email.mime.text.MIMEText(body) msg['Subject'] = subject msg['From'] = sender diff --git a/server/szurubooru/func/mime.py b/server/szurubooru/func/mime.py index 2277ed64..c83f744e 100644 --- a/server/szurubooru/func/mime.py +++ b/server/szurubooru/func/mime.py @@ -1,7 +1,8 @@ import re +from typing import Optional -def get_mime_type(content): +def get_mime_type(content: bytes) -> str: if not content: return 'application/octet-stream' @@ -26,7 +27,7 @@ def get_mime_type(content): return 'application/octet-stream' -def get_extension(mime_type): +def get_extension(mime_type: str) -> Optional[str]: extension_map = { 'application/x-shockwave-flash': 'swf', 'image/gif': 'gif', @@ -39,19 +40,19 @@ def get_extension(mime_type): return extension_map.get((mime_type or '').strip().lower(), None) -def is_flash(mime_type): +def is_flash(mime_type: str) -> bool: return mime_type.lower() == 'application/x-shockwave-flash' -def is_video(mime_type): +def is_video(mime_type: str) -> bool: return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm') -def is_image(mime_type): +def is_image(mime_type: str) -> bool: return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif') -def is_animated_gif(content): +def is_animated_gif(content: bytes) -> bool: pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]' return get_mime_type(content) == 'image/gif' \ and len(re.findall(pattern, content)) > 1 diff --git a/server/szurubooru/func/net.py b/server/szurubooru/func/net.py index fb0c427a..a6e18214 100644 --- a/server/szurubooru/func/net.py +++ b/server/szurubooru/func/net.py @@ -2,7 +2,7 @@ import urllib.request from szurubooru import errors -def download(url): +def download(url: str) -> bytes: assert url request = urllib.request.Request(url) request.add_header('Referer', url) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index c942e799..aa4e137f 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -1,8 +1,10 @@ -import datetime -import sqlalchemy -from szurubooru import config, db, errors +from typing import Any, Optional, Tuple, List, Dict, Callable +from datetime import datetime +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest from szurubooru.func import ( - users, scores, comments, tags, util, mime, images, files, image_hash) + users, scores, comments, tags, util, + mime, images, files, image_hash, serialization) EMPTY_PIXEL = \ @@ -20,7 +22,7 @@ class PostAlreadyFeaturedError(errors.ValidationError): class PostAlreadyUploadedError(errors.ValidationError): - def __init__(self, other_post): + def __init__(self, other_post: model.Post) -> None: super().__init__( 'Post already uploaded (%d)' % other_post.post_id, { @@ -58,30 +60,30 @@ class InvalidPostFlagError(errors.ValidationError): class PostLookalike(image_hash.Lookalike): - def __init__(self, score, distance, post): + def __init__(self, score: int, distance: float, post: model.Post) -> None: super().__init__(score, distance, post.post_id) self.post = post SAFETY_MAP = { - db.Post.SAFETY_SAFE: 'safe', - db.Post.SAFETY_SKETCHY: 'sketchy', - db.Post.SAFETY_UNSAFE: 'unsafe', + model.Post.SAFETY_SAFE: 'safe', + model.Post.SAFETY_SKETCHY: 'sketchy', + model.Post.SAFETY_UNSAFE: 'unsafe', } TYPE_MAP = { - db.Post.TYPE_IMAGE: 'image', - db.Post.TYPE_ANIMATION: 'animation', - db.Post.TYPE_VIDEO: 'video', - db.Post.TYPE_FLASH: 'flash', + model.Post.TYPE_IMAGE: 'image', + model.Post.TYPE_ANIMATION: 'animation', + model.Post.TYPE_VIDEO: 'video', + model.Post.TYPE_FLASH: 'flash', } FLAG_MAP = { - db.Post.FLAG_LOOP: 'loop', + model.Post.FLAG_LOOP: 'loop', } -def get_post_content_url(post): +def get_post_content_url(post: model.Post) -> str: assert post return '%s/posts/%d.%s' % ( config.config['data_url'].rstrip('/'), @@ -89,31 +91,31 @@ def get_post_content_url(post): mime.get_extension(post.mime_type) or 'dat') -def get_post_thumbnail_url(post): +def get_post_thumbnail_url(post: model.Post) -> str: assert post return '%s/generated-thumbnails/%d.jpg' % ( config.config['data_url'].rstrip('/'), post.post_id) -def get_post_content_path(post): +def get_post_content_path(post: model.Post) -> str: assert post assert post.post_id return 'posts/%d.%s' % ( post.post_id, mime.get_extension(post.mime_type) or 'dat') -def get_post_thumbnail_path(post): +def get_post_thumbnail_path(post: model.Post) -> str: assert post return 'generated-thumbnails/%d.jpg' % (post.post_id) -def get_post_thumbnail_backup_path(post): +def get_post_thumbnail_backup_path(post: model.Post) -> str: assert post return 'posts/custom-thumbnails/%d.dat' % (post.post_id) -def serialize_note(note): +def serialize_note(note: model.PostNote) -> rest.Response: assert note return { 'polygon': note.polygon, @@ -121,113 +123,216 @@ def serialize_note(note): } -def serialize_post(post, auth_user, options=None): - return util.serialize_entity( - post, - { - 'id': lambda: post.post_id, - 'version': lambda: post.version, - 'creationTime': lambda: post.creation_time, - 'lastEditTime': lambda: post.last_edit_time, - 'safety': lambda: SAFETY_MAP[post.safety], - 'source': lambda: post.source, - 'type': lambda: TYPE_MAP[post.type], - 'mimeType': lambda: post.mime_type, - 'checksum': lambda: post.checksum, - 'fileSize': lambda: post.file_size, - 'canvasWidth': lambda: post.canvas_width, - 'canvasHeight': lambda: post.canvas_height, - 'contentUrl': lambda: get_post_content_url(post), - 'thumbnailUrl': lambda: get_post_thumbnail_url(post), - 'flags': lambda: post.flags, - 'tags': lambda: [ - tag.names[0].name for tag in tags.sort_tags(post.tags)], - 'relations': lambda: sorted( - { - post['id']: - post for post in [ - serialize_micro_post(rel, auth_user) - for rel in post.relations] - }.values(), - key=lambda post: post['id']), - 'user': lambda: users.serialize_micro_user(post.user, auth_user), - 'score': lambda: post.score, - 'ownScore': lambda: scores.get_score(post, auth_user), - 'ownFavorite': lambda: len([ - user for user in post.favorited_by - if user.user_id == auth_user.user_id] - ) > 0, - 'tagCount': lambda: post.tag_count, - 'favoriteCount': lambda: post.favorite_count, - 'commentCount': lambda: post.comment_count, - 'noteCount': lambda: post.note_count, - 'relationCount': lambda: post.relation_count, - 'featureCount': lambda: post.feature_count, - 'lastFeatureTime': lambda: post.last_feature_time, - 'favoritedBy': lambda: [ - users.serialize_micro_user(rel.user, auth_user) - for rel in post.favorited_by - ], - 'hasCustomThumbnail': - lambda: files.has(get_post_thumbnail_backup_path(post)), - 'notes': lambda: sorted( - [serialize_note(note) for note in post.notes], - key=lambda x: x['polygon']), - 'comments': lambda: [ - comments.serialize_comment(comment, auth_user) - for comment in sorted( - post.comments, - key=lambda comment: comment.creation_time)], - }, - options) +class PostSerializer(serialization.BaseSerializer): + def __init__(self, post: model.Post, auth_user: model.User) -> None: + self.post = post + self.auth_user = auth_user + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'id': self.serialize_id, + 'version': self.serialize_version, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'safety': self.serialize_safety, + 'source': self.serialize_source, + 'type': self.serialize_type, + 'mimeType': self.serialize_mime, + 'checksum': self.serialize_checksum, + 'fileSize': self.serialize_file_size, + 'canvasWidth': self.serialize_canvas_width, + 'canvasHeight': self.serialize_canvas_height, + 'contentUrl': self.serialize_content_url, + 'thumbnailUrl': self.serialize_thumbnail_url, + 'flags': self.serialize_flags, + 'tags': self.serialize_tags, + 'relations': self.serialize_relations, + 'user': self.serialize_user, + 'score': self.serialize_score, + 'ownScore': self.serialize_own_score, + 'ownFavorite': self.serialize_own_favorite, + 'tagCount': self.serialize_tag_count, + 'favoriteCount': self.serialize_favorite_count, + 'commentCount': self.serialize_comment_count, + 'noteCount': self.serialize_note_count, + 'relationCount': self.serialize_relation_count, + 'featureCount': self.serialize_feature_count, + 'lastFeatureTime': self.serialize_last_feature_time, + 'favoritedBy': self.serialize_favorited_by, + 'hasCustomThumbnail': self.serialize_has_custom_thumbnail, + 'notes': self.serialize_notes, + 'comments': self.serialize_comments, + } + + def serialize_id(self) -> Any: + return self.post.post_id + + def serialize_version(self) -> Any: + return self.post.version + + def serialize_creation_time(self) -> Any: + return self.post.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.post.last_edit_time + + def serialize_safety(self) -> Any: + return SAFETY_MAP[self.post.safety] + + def serialize_source(self) -> Any: + return self.post.source + + def serialize_type(self) -> Any: + return TYPE_MAP[self.post.type] + + def serialize_mime(self) -> Any: + return self.post.mime_type + + def serialize_checksum(self) -> Any: + return self.post.checksum + + def serialize_file_size(self) -> Any: + return self.post.file_size + + def serialize_canvas_width(self) -> Any: + return self.post.canvas_width + + def serialize_canvas_height(self) -> Any: + return self.post.canvas_height + + def serialize_content_url(self) -> Any: + return get_post_content_url(self.post) + + def serialize_thumbnail_url(self) -> Any: + return get_post_thumbnail_url(self.post) + + def serialize_flags(self) -> Any: + return self.post.flags + + def serialize_tags(self) -> Any: + return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)] + + def serialize_relations(self) -> Any: + return sorted( + { + post['id']: post + for post in [ + serialize_micro_post(rel, self.auth_user) + for rel in self.post.relations] + }.values(), + key=lambda post: post['id']) + + def serialize_user(self) -> Any: + return users.serialize_micro_user(self.post.user, self.auth_user) + + def serialize_score(self) -> Any: + return self.post.score + + def serialize_own_score(self) -> Any: + return scores.get_score(self.post, self.auth_user) + + def serialize_own_favorite(self) -> Any: + return len([ + user for user in self.post.favorited_by + if user.user_id == self.auth_user.user_id] + ) > 0 + + def serialize_tag_count(self) -> Any: + return self.post.tag_count + + def serialize_favorite_count(self) -> Any: + return self.post.favorite_count + + def serialize_comment_count(self) -> Any: + return self.post.comment_count + + def serialize_note_count(self) -> Any: + return self.post.note_count + + def serialize_relation_count(self) -> Any: + return self.post.relation_count + + def serialize_feature_count(self) -> Any: + return self.post.feature_count + + def serialize_last_feature_time(self) -> Any: + return self.post.last_feature_time + + def serialize_favorited_by(self) -> Any: + return [ + users.serialize_micro_user(rel.user, self.auth_user) + for rel in self.post.favorited_by + ] + + def serialize_has_custom_thumbnail(self) -> Any: + return files.has(get_post_thumbnail_backup_path(self.post)) + + def serialize_notes(self) -> Any: + return sorted( + [serialize_note(note) for note in self.post.notes], + key=lambda x: x['polygon']) + + def serialize_comments(self) -> Any: + return [ + comments.serialize_comment(comment, self.auth_user) + for comment in sorted( + self.post.comments, + key=lambda comment: comment.creation_time)] -def serialize_micro_post(post, auth_user): +def serialize_post( + post: Optional[model.Post], + auth_user: model.User, + options: List[str]=[]) -> Optional[rest.Response]: + if not post: + return None + return PostSerializer(post, auth_user).serialize(options) + + +def serialize_micro_post( + post: model.Post, auth_user: model.User) -> Optional[rest.Response]: return serialize_post( - post, - auth_user=auth_user, - options=['id', 'thumbnailUrl']) + post, auth_user=auth_user, options=['id', 'thumbnailUrl']) -def get_post_count(): - return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] +def get_post_count() -> int: + return db.session.query(sa.func.count(model.Post.post_id)).one()[0] -def try_get_post_by_id(post_id): - try: - post_id = int(post_id) - except ValueError: - raise InvalidPostIdError('Invalid post ID: %r.' % post_id) +def try_get_post_by_id(post_id: int) -> Optional[model.Post]: return db.session \ - .query(db.Post) \ - .filter(db.Post.post_id == post_id) \ + .query(model.Post) \ + .filter(model.Post.post_id == post_id) \ .one_or_none() -def get_post_by_id(post_id): +def get_post_by_id(post_id: int) -> model.Post: post = try_get_post_by_id(post_id) if not post: raise PostNotFoundError('Post %r not found.' % post_id) return post -def try_get_current_post_feature(): +def try_get_current_post_feature() -> Optional[model.PostFeature]: return db.session \ - .query(db.PostFeature) \ - .order_by(db.PostFeature.time.desc()) \ + .query(model.PostFeature) \ + .order_by(model.PostFeature.time.desc()) \ .first() -def try_get_featured_post(): +def try_get_featured_post() -> Optional[model.Post]: post_feature = try_get_current_post_feature() return post_feature.post if post_feature else None -def create_post(content, tag_names, user): - post = db.Post() - post.safety = db.Post.SAFETY_SAFE +def create_post( + content: bytes, + tag_names: List[str], + user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]: + post = model.Post() + post.safety = model.Post.SAFETY_SAFE post.user = user - post.creation_time = datetime.datetime.utcnow() + post.creation_time = datetime.utcnow() post.flags = [] post.type = '' @@ -240,7 +345,7 @@ def create_post(content, tag_names, user): return (post, new_tags) -def update_post_safety(post, safety): +def update_post_safety(post: model.Post, safety: str) -> None: assert post safety = util.flip(SAFETY_MAP).get(safety, None) if not safety: @@ -249,30 +354,33 @@ def update_post_safety(post, safety): post.safety = safety -def update_post_source(post, source): +def update_post_source(post: model.Post, source: Optional[str]) -> None: assert post - if util.value_exceeds_column_size(source, db.Post.source): + if util.value_exceeds_column_size(source, model.Post.source): raise InvalidPostSourceError('Source is too long.') - post.source = source + post.source = source or None -@sqlalchemy.events.event.listens_for(db.Post, 'after_insert') -def _after_post_insert(_mapper, _connection, post): +@sa.events.event.listens_for(model.Post, 'after_insert') +def _after_post_insert( + _mapper: Any, _connection: Any, post: model.Post) -> None: _sync_post_content(post) -@sqlalchemy.events.event.listens_for(db.Post, 'after_update') -def _after_post_update(_mapper, _connection, post): +@sa.events.event.listens_for(model.Post, 'after_update') +def _after_post_update( + _mapper: Any, _connection: Any, post: model.Post) -> None: _sync_post_content(post) -@sqlalchemy.events.event.listens_for(db.Post, 'before_delete') -def _before_post_delete(_mapper, _connection, post): +@sa.events.event.listens_for(model.Post, 'before_delete') +def _before_post_delete( + _mapper: Any, _connection: Any, post: model.Post) -> None: if post.post_id: image_hash.delete_image(post.post_id) -def _sync_post_content(post): +def _sync_post_content(post: model.Post) -> None: regenerate_thumb = False if hasattr(post, '__content'): @@ -281,7 +389,7 @@ def _sync_post_content(post): delattr(post, '__content') regenerate_thumb = True if post.post_id and post.type in ( - db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): image_hash.delete_image(post.post_id) image_hash.add_image(post.post_id, content) @@ -299,29 +407,29 @@ def _sync_post_content(post): generate_post_thumbnail(post) -def update_post_content(post, content): +def update_post_content(post: model.Post, content: Optional[bytes]) -> None: assert post if not content: raise InvalidPostContentError('Post content missing.') post.mime_type = mime.get_mime_type(content) if mime.is_flash(post.mime_type): - post.type = db.Post.TYPE_FLASH + post.type = model.Post.TYPE_FLASH elif mime.is_image(post.mime_type): if mime.is_animated_gif(content): - post.type = db.Post.TYPE_ANIMATION + post.type = model.Post.TYPE_ANIMATION else: - post.type = db.Post.TYPE_IMAGE + post.type = model.Post.TYPE_IMAGE elif mime.is_video(post.mime_type): - post.type = db.Post.TYPE_VIDEO + post.type = model.Post.TYPE_VIDEO else: raise InvalidPostContentError( 'Unhandled file type: %r' % post.mime_type) post.checksum = util.get_sha1(content) other_post = db.session \ - .query(db.Post) \ - .filter(db.Post.checksum == post.checksum) \ - .filter(db.Post.post_id != post.post_id) \ + .query(model.Post) \ + .filter(model.Post.checksum == post.checksum) \ + .filter(model.Post.post_id != post.post_id) \ .one_or_none() if other_post \ and other_post.post_id \ @@ -343,18 +451,20 @@ def update_post_content(post, content): setattr(post, '__content', content) -def update_post_thumbnail(post, content=None): +def update_post_thumbnail( + post: model.Post, content: Optional[bytes]=None) -> None: assert post setattr(post, '__thumbnail', content) -def generate_post_thumbnail(post): +def generate_post_thumbnail(post: model.Post) -> None: assert post if files.has(get_post_thumbnail_backup_path(post)): content = files.get(get_post_thumbnail_backup_path(post)) else: content = files.get(get_post_content_path(post)) try: + assert content image = images.Image(content) image.resize_fill( int(config.config['thumbnails']['post_width']), @@ -364,14 +474,15 @@ def generate_post_thumbnail(post): files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) -def update_post_tags(post, tag_names): +def update_post_tags( + post: model.Post, tag_names: List[str]) -> List[model.Tag]: assert post existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) post.tags = existing_tags + new_tags return new_tags -def update_post_relations(post, new_post_ids): +def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None: assert post try: new_post_ids = [int(id) for id in new_post_ids] @@ -382,8 +493,8 @@ def update_post_relations(post, new_post_ids): old_post_ids = [int(p.post_id) for p in old_posts] if new_post_ids: new_posts = db.session \ - .query(db.Post) \ - .filter(db.Post.post_id.in_(new_post_ids)) \ + .query(model.Post) \ + .filter(model.Post.post_id.in_(new_post_ids)) \ .all() else: new_posts = [] @@ -402,7 +513,7 @@ def update_post_relations(post, new_post_ids): relation.relations.append(post) -def update_post_notes(post, notes): +def update_post_notes(post: model.Post, notes: Any) -> None: assert post post.notes = [] for note in notes: @@ -433,13 +544,13 @@ def update_post_notes(post, notes): except ValueError: raise InvalidPostNoteError( 'A point in note\'s polygon must be numeric.') - if util.value_exceeds_column_size(note['text'], db.PostNote.text): + if util.value_exceeds_column_size(note['text'], model.PostNote.text): raise InvalidPostNoteError('Note text is too long.') post.notes.append( - db.PostNote(polygon=note['polygon'], text=str(note['text']))) + model.PostNote(polygon=note['polygon'], text=str(note['text']))) -def update_post_flags(post, flags): +def update_post_flags(post: model.Post, flags: List[str]) -> None: assert post target_flags = [] for flag in flags: @@ -451,88 +562,95 @@ def update_post_flags(post, flags): post.flags = target_flags -def feature_post(post, user): +def feature_post(post: model.Post, user: Optional[model.User]) -> None: assert post - post_feature = db.PostFeature() - post_feature.time = datetime.datetime.utcnow() + post_feature = model.PostFeature() + post_feature.time = datetime.utcnow() post_feature.post = post post_feature.user = user db.session.add(post_feature) -def delete(post): +def delete(post: model.Post) -> None: assert post db.session.delete(post) -def merge_posts(source_post, target_post, replace_content): +def merge_posts( + source_post: model.Post, + target_post: model.Post, + replace_content: bool) -> None: assert source_post assert target_post if source_post.post_id == target_post.post_id: raise InvalidPostRelationError('Cannot merge post with itself.') - def merge_tables(table, anti_dup_func, source_post_id, target_post_id): + def merge_tables( + table: model.Base, + anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]], + source_post_id: int, + target_post_id: int) -> None: alias1 = table - alias2 = sqlalchemy.orm.util.aliased(table) + alias2 = sa.orm.util.aliased(table) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.post_id == source_post_id)) if anti_dup_func is not None: update_stmt = ( update_stmt .where( - ~sqlalchemy.exists() + ~sa.exists() .where(anti_dup_func(alias1, alias2)) .where(alias2.post_id == target_post_id))) update_stmt = update_stmt.values(post_id=target_post_id) db.session.execute(update_stmt) - def merge_tags(source_post_id, target_post_id): + def merge_tags(source_post_id: int, target_post_id: int) -> None: merge_tables( - db.PostTag, + model.PostTag, lambda alias1, alias2: alias1.tag_id == alias2.tag_id, source_post_id, target_post_id) - def merge_scores(source_post_id, target_post_id): + def merge_scores(source_post_id: int, target_post_id: int) -> None: merge_tables( - db.PostScore, + model.PostScore, lambda alias1, alias2: alias1.user_id == alias2.user_id, source_post_id, target_post_id) - def merge_favorites(source_post_id, target_post_id): + def merge_favorites(source_post_id: int, target_post_id: int) -> None: merge_tables( - db.PostFavorite, + model.PostFavorite, lambda alias1, alias2: alias1.user_id == alias2.user_id, source_post_id, target_post_id) - def merge_comments(source_post_id, target_post_id): - merge_tables(db.Comment, None, source_post_id, target_post_id) + def merge_comments(source_post_id: int, target_post_id: int) -> None: + merge_tables(model.Comment, None, source_post_id, target_post_id) - def merge_relations(source_post_id, target_post_id): - alias1 = db.PostRelation - alias2 = sqlalchemy.orm.util.aliased(db.PostRelation) + def merge_relations(source_post_id: int, target_post_id: int) -> None: + alias1 = model.PostRelation + alias2 = sa.orm.util.aliased(model.PostRelation) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.parent_id == source_post_id) .where(alias1.child_id != target_post_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.child_id == alias1.child_id) .where(alias2.parent_id == target_post_id)) .values(parent_id=target_post_id)) db.session.execute(update_stmt) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.child_id == source_post_id) .where(alias1.parent_id != target_post_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.parent_id == alias1.parent_id) .where(alias2.child_id == target_post_id)) .values(child_id=target_post_id)) @@ -553,15 +671,15 @@ def merge_posts(source_post, target_post, replace_content): update_post_content(target_post, content) -def search_by_image_exact(image_content): +def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: checksum = util.get_sha1(image_content) return db.session \ - .query(db.Post) \ - .filter(db.Post.checksum == checksum) \ + .query(model.Post) \ + .filter(model.Post.checksum == checksum) \ .one_or_none() -def search_by_image(image_content): +def search_by_image(image_content: bytes) -> List[PostLookalike]: ret = [] for result in image_hash.search_by_image(image_content): ret.append(PostLookalike( @@ -571,24 +689,24 @@ def search_by_image(image_content): return ret -def populate_reverse_search(): +def populate_reverse_search() -> None: excluded_post_ids = image_hash.get_all_paths() post_ids_to_hash = ( db.session - .query(db.Post.post_id) + .query(model.Post.post_id) .filter( - (db.Post.type == db.Post.TYPE_IMAGE) | - (db.Post.type == db.Post.TYPE_ANIMATION)) - .filter(~db.Post.post_id.in_(excluded_post_ids)) - .order_by(db.Post.post_id.asc()) + (model.Post.type == model.Post.TYPE_IMAGE) | + (model.Post.type == model.Post.TYPE_ANIMATION)) + .filter(~model.Post.post_id.in_(excluded_post_ids)) + .order_by(model.Post.post_id.asc()) .all()) for post_ids_chunk in util.chunks(post_ids_to_hash, 100): posts_chunk = ( db.session - .query(db.Post) - .filter(db.Post.post_id.in_(post_ids_chunk)) + .query(model.Post) + .filter(model.Post.post_id.in_(post_ids_chunk)) .all()) for post in posts_chunk: content_path = get_post_content_path(post) diff --git a/server/szurubooru/func/scores.py b/server/szurubooru/func/scores.py index a42961f2..fde279eb 100644 --- a/server/szurubooru/func/scores.py +++ b/server/szurubooru/func/scores.py @@ -1,5 +1,6 @@ import datetime -from szurubooru import db, errors +from typing import Any, Tuple, Callable +from szurubooru import db, model, errors class InvalidScoreTargetError(errors.ValidationError): @@ -10,22 +11,23 @@ class InvalidScoreValueError(errors.ValidationError): pass -def _get_table_info(entity): +def _get_table_info( + entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]: assert entity - resource_type, _, _ = db.util.get_resource_info(entity) + resource_type, _, _ = model.util.get_resource_info(entity) if resource_type == 'post': - return db.PostScore, lambda table: table.post_id + return model.PostScore, lambda table: table.post_id elif resource_type == 'comment': - return db.CommentScore, lambda table: table.comment_id + return model.CommentScore, lambda table: table.comment_id raise InvalidScoreTargetError() -def _get_score_entity(entity, user): +def _get_score_entity(entity: model.Base, user: model.User) -> model.Base: assert user - return db.util.get_aux_entity(db.session, _get_table_info, entity, user) + return model.util.get_aux_entity(db.session, _get_table_info, entity, user) -def delete_score(entity, user): +def delete_score(entity: model.Base, user: model.User) -> None: assert entity assert user score_entity = _get_score_entity(entity, user) @@ -33,7 +35,7 @@ def delete_score(entity, user): db.session.delete(score_entity) -def get_score(entity, user): +def get_score(entity: model.Base, user: model.User) -> int: assert entity assert user table, get_column = _get_table_info(entity) @@ -45,7 +47,7 @@ def get_score(entity, user): return row[0] if row else 0 -def set_score(entity, user, score): +def set_score(entity: model.Base, user: model.User, score: int) -> None: from szurubooru.func import favorites assert entity assert user diff --git a/server/szurubooru/func/serialization.py b/server/szurubooru/func/serialization.py new file mode 100644 index 00000000..df78959f --- /dev/null +++ b/server/szurubooru/func/serialization.py @@ -0,0 +1,27 @@ +from typing import Any, Optional, List, Dict, Callable +from szurubooru import db, model, rest, errors + + +def get_serialization_options(ctx: rest.Context) -> List[str]: + return ctx.get_param_as_list('fields', default=[]) + + +class BaseSerializer: + _fields = {} # type: Dict[str, Callable[[model.Base], Any]] + + def serialize(self, options: List[str]) -> Any: + field_factories = self._serializers() + if not options: + options = list(field_factories.keys()) + ret = {} + for key in options: + if key not in field_factories: + raise errors.ValidationError( + 'Invalid key: %r. Valid keys: %r.' % ( + key, list(sorted(field_factories.keys())))) + factory = field_factories[key] + ret[key] = factory() + return ret + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + raise NotImplementedError() diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index f7efda9e..240c3bce 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -1,9 +1,10 @@ +from typing import Any, Optional, Dict, Callable from datetime import datetime -from szurubooru import db +from szurubooru import db, model from szurubooru.func import diff, users -def get_tag_category_snapshot(category): +def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]: assert category return { 'name': category.name, @@ -12,7 +13,7 @@ def get_tag_category_snapshot(category): } -def get_tag_snapshot(tag): +def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]: assert tag return { 'names': [tag_name.name for tag_name in tag.names], @@ -22,7 +23,7 @@ def get_tag_snapshot(tag): } -def get_post_snapshot(post): +def get_post_snapshot(post: model.Post) -> Dict[str, Any]: assert post return { 'source': post.source, @@ -45,10 +46,11 @@ _snapshot_factories = { 'tag_category': lambda entity: get_tag_category_snapshot(entity), 'tag': lambda entity: get_tag_snapshot(entity), 'post': lambda entity: get_post_snapshot(entity), -} +} # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]] -def serialize_snapshot(snapshot, auth_user): +def serialize_snapshot( + snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]: assert snapshot return { 'operation': snapshot.operation, @@ -60,11 +62,14 @@ def serialize_snapshot(snapshot, auth_user): } -def _create(operation, entity, auth_user): +def _create( + operation: str, + entity: model.Base, + auth_user: Optional[model.User]) -> model.Snapshot: resource_type, resource_pkey, resource_name = ( - db.util.get_resource_info(entity)) + model.util.get_resource_info(entity)) - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.creation_time = datetime.utcnow() snapshot.operation = operation snapshot.resource_type = resource_type @@ -74,33 +79,33 @@ def _create(operation, entity, auth_user): return snapshot -def create(entity, auth_user): +def create(entity: model.Base, auth_user: Optional[model.User]) -> None: assert entity - snapshot = _create(db.Snapshot.OPERATION_CREATED, entity, auth_user) + snapshot = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot.data = snapshot_factory(entity) db.session.add(snapshot) # pylint: disable=protected-access -def modify(entity, auth_user): +def modify(entity: model.Base, auth_user: Optional[model.User]) -> None: assert entity - model = next( + table = next( ( - model - for model in db.Base._decl_class_registry.values() - if hasattr(model, '__table__') - and model.__table__.fullname == entity.__table__.fullname + cls + for cls in model.Base._decl_class_registry.values() + if hasattr(cls, '__table__') + and cls.__table__.fullname == entity.__table__.fullname ), None) - assert model + assert table - snapshot = _create(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) + snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] detached_session = db.sessionmaker() - detached_entity = detached_session.query(model).get(snapshot.resource_pkey) + detached_entity = detached_session.query(table).get(snapshot.resource_pkey) assert detached_entity, 'Entity not found in DB, have you committed it?' detached_snapshot = snapshot_factory(detached_entity) detached_session.close() @@ -113,19 +118,23 @@ def modify(entity, auth_user): db.session.add(snapshot) -def delete(entity, auth_user): +def delete(entity: model.Base, auth_user: Optional[model.User]) -> None: assert entity - snapshot = _create(db.Snapshot.OPERATION_DELETED, entity, auth_user) + snapshot = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot.data = snapshot_factory(entity) db.session.add(snapshot) -def merge(source_entity, target_entity, auth_user): +def merge( + source_entity: model.Base, + target_entity: model.Base, + auth_user: Optional[model.User]) -> None: assert source_entity assert target_entity - snapshot = _create(db.Snapshot.OPERATION_MERGED, source_entity, auth_user) + snapshot = _create( + model.Snapshot.OPERATION_MERGED, source_entity, auth_user) resource_type, _resource_pkey, resource_name = ( - db.util.get_resource_info(target_entity)) + model.util.get_resource_info(target_entity)) snapshot.data = [resource_type, resource_name] db.session.add(snapshot) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index a9169dec..41c9c928 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -1,7 +1,8 @@ import re -import sqlalchemy -from szurubooru import config, db, errors -from szurubooru.func import util, cache +from typing import Any, Optional, Dict, List, Callable +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest +from szurubooru.func import util, serialization, cache DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category' @@ -27,28 +28,52 @@ class InvalidTagCategoryColorError(errors.ValidationError): pass -def _verify_name_validity(name): +def _verify_name_validity(name: str) -> None: name_regex = config.config['tag_category_name_regex'] if not re.match(name_regex, name): raise InvalidTagCategoryNameError( 'Name must satisfy regex %r.' % name_regex) -def serialize_category(category, options=None): - return util.serialize_entity( - category, - { - 'name': lambda: category.name, - 'version': lambda: category.version, - 'color': lambda: category.color, - 'usages': lambda: category.tag_count, - 'default': lambda: category.default, - }, - options) +class TagCategorySerializer(serialization.BaseSerializer): + def __init__(self, category: model.TagCategory) -> None: + self.category = category + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'name': self.serialize_name, + 'version': self.serialize_version, + 'color': self.serialize_color, + 'usages': self.serialize_usages, + 'default': self.serialize_default, + } + + def serialize_name(self) -> Any: + return self.category.name + + def serialize_version(self) -> Any: + return self.category.version + + def serialize_color(self) -> Any: + return self.category.color + + def serialize_usages(self) -> Any: + return self.category.tag_count + + def serialize_default(self) -> Any: + return self.category.default -def create_category(name, color): - category = db.TagCategory() +def serialize_category( + category: Optional[model.TagCategory], + options: List[str]=[]) -> Optional[rest.Response]: + if not category: + return None + return TagCategorySerializer(category).serialize(options) + + +def create_category(name: str, color: str) -> model.TagCategory: + category = model.TagCategory() update_category_name(category, name) update_category_color(category, color) if not get_all_categories(): @@ -56,64 +81,66 @@ def create_category(name, color): return category -def update_category_name(category, name): +def update_category_name(category: model.TagCategory, name: str) -> None: assert category if not name: raise InvalidTagCategoryNameError('Name cannot be empty.') - expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() + expr = sa.func.lower(model.TagCategory.name) == name.lower() if category.tag_category_id: expr = expr & ( - db.TagCategory.tag_category_id != category.tag_category_id) - already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 + model.TagCategory.tag_category_id != category.tag_category_id) + already_exists = ( + db.session.query(model.TagCategory).filter(expr).count() > 0) if already_exists: raise TagCategoryAlreadyExistsError( 'A category with this name already exists.') - if util.value_exceeds_column_size(name, db.TagCategory.name): + if util.value_exceeds_column_size(name, model.TagCategory.name): raise InvalidTagCategoryNameError('Name is too long.') _verify_name_validity(name) category.name = name cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) -def update_category_color(category, color): +def update_category_color(category: model.TagCategory, color: str) -> None: assert category if not color: raise InvalidTagCategoryColorError('Color cannot be empty.') if not re.match(r'^#?[0-9a-z]+$', color): raise InvalidTagCategoryColorError('Invalid color.') - if util.value_exceeds_column_size(color, db.TagCategory.color): + if util.value_exceeds_column_size(color, model.TagCategory.color): raise InvalidTagCategoryColorError('Color is too long.') category.color = color -def try_get_category_by_name(name, lock=False): +def try_get_category_by_name( + name: str, lock: bool=False) -> Optional[model.TagCategory]: query = db.session \ - .query(db.TagCategory) \ - .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) + .query(model.TagCategory) \ + .filter(sa.func.lower(model.TagCategory.name) == name.lower()) if lock: query = query.with_lockmode('update') return query.one_or_none() -def get_category_by_name(name, lock=False): +def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory: category = try_get_category_by_name(name, lock) if not category: raise TagCategoryNotFoundError('Tag category %r not found.' % name) return category -def get_all_category_names(): - return [row[0] for row in db.session.query(db.TagCategory.name).all()] +def get_all_category_names() -> List[str]: + return [row[0] for row in db.session.query(model.TagCategory.name).all()] -def get_all_categories(): - return db.session.query(db.TagCategory).all() +def get_all_categories() -> List[model.TagCategory]: + return db.session.query(model.TagCategory).all() -def try_get_default_category(lock=False): +def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]: query = db.session \ - .query(db.TagCategory) \ - .filter(db.TagCategory.default) + .query(model.TagCategory) \ + .filter(model.TagCategory.default) if lock: query = query.with_lockmode('update') category = query.first() @@ -121,22 +148,22 @@ def try_get_default_category(lock=False): # category, get the first record available. if not category: query = db.session \ - .query(db.TagCategory) \ - .order_by(db.TagCategory.tag_category_id.asc()) + .query(model.TagCategory) \ + .order_by(model.TagCategory.tag_category_id.asc()) if lock: query = query.with_lockmode('update') category = query.first() return category -def get_default_category(lock=False): +def get_default_category(lock: bool=False) -> model.TagCategory: category = try_get_default_category(lock) if not category: raise TagCategoryNotFoundError('No tag category created yet.') return category -def get_default_category_name(): +def get_default_category_name() -> str: if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY): return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY) default_category = get_default_category() @@ -145,7 +172,7 @@ def get_default_category_name(): return default_category_name -def set_default_category(category): +def set_default_category(category: model.TagCategory) -> None: assert category old_category = try_get_default_category(lock=True) if old_category: @@ -156,7 +183,7 @@ def set_default_category(category): cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) -def delete_category(category): +def delete_category(category: model.TagCategory) -> None: assert category if len(get_all_category_names()) == 1: raise TagCategoryIsInUseError('Cannot delete the last category.') diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 1665282b..fb043245 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -1,10 +1,11 @@ -import datetime import json import os import re -import sqlalchemy -from szurubooru import config, db, errors -from szurubooru.func import util, tag_categories +from typing import Any, Optional, Tuple, List, Dict, Callable +from datetime import datetime +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest +from szurubooru.func import util, tag_categories, serialization class TagNotFoundError(errors.NotFoundError): @@ -35,31 +36,32 @@ class InvalidTagDescriptionError(errors.ValidationError): pass -def _verify_name_validity(name): - if util.value_exceeds_column_size(name, db.TagName.name): +def _verify_name_validity(name: str) -> None: + if util.value_exceeds_column_size(name, model.TagName.name): raise InvalidTagNameError('Name is too long.') name_regex = config.config['tag_name_regex'] if not re.match(name_regex, name): raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) -def _get_names(tag): +def _get_names(tag: model.Tag) -> List[str]: assert tag return [tag_name.name for tag_name in tag.names] -def _lower_list(names): +def _lower_list(names: List[str]) -> List[str]: return [name.lower() for name in names] -def _check_name_intersection(names1, names2, case_sensitive): +def _check_name_intersection( + names1: List[str], names2: List[str], case_sensitive: bool) -> bool: if not case_sensitive: names1 = _lower_list(names1) names2 = _lower_list(names2) return len(set(names1).intersection(names2)) > 0 -def sort_tags(tags): +def sort_tags(tags: List[model.Tag]) -> List[model.Tag]: default_category_name = tag_categories.get_default_category_name() return sorted( tags, @@ -70,35 +72,70 @@ def sort_tags(tags): ) -def serialize_tag(tag, options=None): - return util.serialize_entity( - tag, - { - 'names': lambda: [tag_name.name for tag_name in tag.names], - 'category': lambda: tag.category.name, - 'version': lambda: tag.version, - 'description': lambda: tag.description, - 'creationTime': lambda: tag.creation_time, - 'lastEditTime': lambda: tag.last_edit_time, - 'usages': lambda: tag.post_count, - 'suggestions': lambda: [ - relation.names[0].name - for relation in sort_tags(tag.suggestions)], - 'implications': lambda: [ - relation.names[0].name - for relation in sort_tags(tag.implications)], - }, - options) +class TagSerializer(serialization.BaseSerializer): + def __init__(self, tag: model.Tag) -> None: + self.tag = tag + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'names': self.serialize_names, + 'category': self.serialize_category, + 'version': self.serialize_version, + 'description': self.serialize_description, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'usages': self.serialize_usages, + 'suggestions': self.serialize_suggestions, + 'implications': self.serialize_implications, + } + + def serialize_names(self) -> Any: + return [tag_name.name for tag_name in self.tag.names] + + def serialize_category(self) -> Any: + return self.tag.category.name + + def serialize_version(self) -> Any: + return self.tag.version + + def serialize_description(self) -> Any: + return self.tag.description + + def serialize_creation_time(self) -> Any: + return self.tag.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.tag.last_edit_time + + def serialize_usages(self) -> Any: + return self.tag.post_count + + def serialize_suggestions(self) -> Any: + return [ + relation.names[0].name + for relation in sort_tags(self.tag.suggestions)] + + def serialize_implications(self) -> Any: + return [ + relation.names[0].name + for relation in sort_tags(self.tag.implications)] -def export_to_json(): - tags = {} - categories = {} +def serialize_tag( + tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]: + if not tag: + return None + return TagSerializer(tag).serialize(options) + + +def export_to_json() -> None: + tags = {} # type: Dict[int, Any] + categories = {} # type: Dict[int, Any] for result in db.session.query( - db.TagCategory.tag_category_id, - db.TagCategory.name, - db.TagCategory.color).all(): + model.TagCategory.tag_category_id, + model.TagCategory.name, + model.TagCategory.color).all(): categories[result[0]] = { 'name': result[1], 'color': result[2], @@ -106,8 +143,8 @@ def export_to_json(): for result in ( db.session - .query(db.TagName.tag_id, db.TagName.name) - .order_by(db.TagName.order) + .query(model.TagName.tag_id, model.TagName.name) + .order_by(model.TagName.order) .all()): if not result[0] in tags: tags[result[0]] = {'names': []} @@ -115,8 +152,10 @@ def export_to_json(): for result in ( db.session - .query(db.TagSuggestion.parent_id, db.TagName.name) - .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) + .query(model.TagSuggestion.parent_id, model.TagName.name) + .join( + model.TagName, + model.TagName.tag_id == model.TagSuggestion.child_id) .all()): if 'suggestions' not in tags[result[0]]: tags[result[0]]['suggestions'] = [] @@ -124,17 +163,19 @@ def export_to_json(): for result in ( db.session - .query(db.TagImplication.parent_id, db.TagName.name) - .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) + .query(model.TagImplication.parent_id, model.TagName.name) + .join( + model.TagName, + model.TagName.tag_id == model.TagImplication.child_id) .all()): if 'implications' not in tags[result[0]]: tags[result[0]]['implications'] = [] tags[result[0]]['implications'].append(result[1]) for result in db.session.query( - db.Tag.tag_id, - db.Tag.category_id, - db.Tag.post_count).all(): + model.Tag.tag_id, + model.Tag.category_id, + model.Tag.post_count).all(): tags[result[0]]['category'] = categories[result[1]]['name'] tags[result[0]]['usages'] = result[2] @@ -148,33 +189,34 @@ def export_to_json(): handle.write(json.dumps(output, separators=(',', ':'))) -def try_get_tag_by_name(name): +def try_get_tag_by_name(name: str) -> Optional[model.Tag]: return ( db.session - .query(db.Tag) - .join(db.TagName) - .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) + .query(model.Tag) + .join(model.TagName) + .filter(sa.func.lower(model.TagName.name) == name.lower()) .one_or_none()) -def get_tag_by_name(name): +def get_tag_by_name(name: str) -> model.Tag: tag = try_get_tag_by_name(name) if not tag: raise TagNotFoundError('Tag %r not found.' % name) return tag -def get_tags_by_names(names): +def get_tags_by_names(names: List[str]) -> List[model.Tag]: names = util.icase_unique(names) if len(names) == 0: return [] - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for name in names: - expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) - return db.session.query(db.Tag).join(db.TagName).filter(expr).all() + expr = expr | (sa.func.lower(model.TagName.name) == name.lower()) + return db.session.query(model.Tag).join(model.TagName).filter(expr).all() -def get_or_create_tags_by_names(names): +def get_or_create_tags_by_names( + names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]: names = util.icase_unique(names) existing_tags = get_tags_by_names(names) new_tags = [] @@ -197,86 +239,87 @@ def get_or_create_tags_by_names(names): return existing_tags, new_tags -def get_tag_siblings(tag): +def get_tag_siblings(tag: model.Tag) -> List[model.Tag]: assert tag - tag_alias = sqlalchemy.orm.aliased(db.Tag) - pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) - pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) + tag_alias = sa.orm.aliased(model.Tag) + pt_alias1 = sa.orm.aliased(model.PostTag) + pt_alias2 = sa.orm.aliased(model.PostTag) result = ( db.session - .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) + .query(tag_alias, sa.func.count(pt_alias2.post_id)) .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) .filter(pt_alias2.tag_id == tag.tag_id) .filter(pt_alias1.tag_id != tag.tag_id) .group_by(tag_alias.tag_id) - .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) + .order_by(sa.func.count(pt_alias2.post_id).desc()) .limit(50)) return result -def delete(source_tag): +def delete(source_tag: model.Tag) -> None: assert source_tag db.session.execute( - sqlalchemy.sql.expression.delete(db.TagSuggestion) - .where(db.TagSuggestion.child_id == source_tag.tag_id)) + sa.sql.expression.delete(model.TagSuggestion) + .where(model.TagSuggestion.child_id == source_tag.tag_id)) db.session.execute( - sqlalchemy.sql.expression.delete(db.TagImplication) - .where(db.TagImplication.child_id == source_tag.tag_id)) + sa.sql.expression.delete(model.TagImplication) + .where(model.TagImplication.child_id == source_tag.tag_id)) db.session.delete(source_tag) -def merge_tags(source_tag, target_tag): +def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None: assert source_tag assert target_tag if source_tag.tag_id == target_tag.tag_id: raise InvalidTagRelationError('Cannot merge tag with itself.') - def merge_posts(source_tag_id, target_tag_id): - alias1 = db.PostTag - alias2 = sqlalchemy.orm.util.aliased(db.PostTag) + def merge_posts(source_tag_id: int, target_tag_id: int) -> None: + alias1 = model.PostTag + alias2 = sa.orm.util.aliased(model.PostTag) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.tag_id == source_tag_id)) update_stmt = ( update_stmt .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias1.post_id == alias2.post_id) .where(alias2.tag_id == target_tag_id))) update_stmt = update_stmt.values(tag_id=target_tag_id) db.session.execute(update_stmt) - def merge_relations(table, source_tag_id, target_tag_id): + def merge_relations( + table: model.Base, source_tag_id: int, target_tag_id: int) -> None: alias1 = table - alias2 = sqlalchemy.orm.util.aliased(table) + alias2 = sa.orm.util.aliased(table) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.parent_id == source_tag_id) .where(alias1.child_id != target_tag_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.child_id == alias1.child_id) .where(alias2.parent_id == target_tag_id)) .values(parent_id=target_tag_id)) db.session.execute(update_stmt) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.child_id == source_tag_id) .where(alias1.parent_id != target_tag_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.parent_id == alias1.parent_id) .where(alias2.child_id == target_tag_id)) .values(child_id=target_tag_id)) db.session.execute(update_stmt) - def merge_suggestions(source_tag_id, target_tag_id): - merge_relations(db.TagSuggestion, source_tag_id, target_tag_id) + def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None: + merge_relations(model.TagSuggestion, source_tag_id, target_tag_id) - def merge_implications(source_tag_id, target_tag_id): - merge_relations(db.TagImplication, source_tag_id, target_tag_id) + def merge_implications(source_tag_id: int, target_tag_id: int) -> None: + merge_relations(model.TagImplication, source_tag_id, target_tag_id) merge_posts(source_tag.tag_id, target_tag.tag_id) merge_suggestions(source_tag.tag_id, target_tag.tag_id) @@ -284,9 +327,13 @@ def merge_tags(source_tag, target_tag): delete(source_tag) -def create_tag(names, category_name, suggestions, implications): - tag = db.Tag() - tag.creation_time = datetime.datetime.utcnow() +def create_tag( + names: List[str], + category_name: str, + suggestions: List[str], + implications: List[str]) -> model.Tag: + tag = model.Tag() + tag.creation_time = datetime.utcnow() update_tag_names(tag, names) update_tag_category_name(tag, category_name) update_tag_suggestions(tag, suggestions) @@ -294,12 +341,12 @@ def create_tag(names, category_name, suggestions, implications): return tag -def update_tag_category_name(tag, category_name): +def update_tag_category_name(tag: model.Tag, category_name: str) -> None: assert tag tag.category = tag_categories.get_category_by_name(category_name) -def update_tag_names(tag, names): +def update_tag_names(tag: model.Tag, names: List[str]) -> None: # sanitize assert tag names = util.icase_unique([name for name in names if name]) @@ -309,12 +356,12 @@ def update_tag_names(tag, names): _verify_name_validity(name) # check for existing tags - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for name in names: - expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) + expr = expr | (sa.func.lower(model.TagName.name) == name.lower()) if tag.tag_id: - expr = expr & (db.TagName.tag_id != tag.tag_id) - existing_tags = db.session.query(db.TagName).filter(expr).all() + expr = expr & (model.TagName.tag_id != tag.tag_id) + existing_tags = db.session.query(model.TagName).filter(expr).all() if len(existing_tags): raise TagAlreadyExistsError( 'One of names is already used by another tag.') @@ -326,7 +373,7 @@ def update_tag_names(tag, names): # add wanted items for name in names: if not _check_name_intersection(_get_names(tag), [name], True): - tag.names.append(db.TagName(name, None)) + tag.names.append(model.TagName(name, -1)) # set alias order to match the request for i, name in enumerate(names): @@ -336,7 +383,7 @@ def update_tag_names(tag, names): # TODO: what to do with relations that do not yet exist? -def update_tag_implications(tag, relations): +def update_tag_implications(tag: model.Tag, relations: List[str]) -> None: assert tag if _check_name_intersection(_get_names(tag), relations, False): raise InvalidTagRelationError('Tag cannot imply itself.') @@ -344,15 +391,15 @@ def update_tag_implications(tag, relations): # TODO: what to do with relations that do not yet exist? -def update_tag_suggestions(tag, relations): +def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None: assert tag if _check_name_intersection(_get_names(tag), relations, False): raise InvalidTagRelationError('Tag cannot suggest itself.') tag.suggestions = get_tags_by_names(relations) -def update_tag_description(tag, description): +def update_tag_description(tag: model.Tag, description: str) -> None: assert tag - if util.value_exceeds_column_size(description, db.Tag.description): + if util.value_exceeds_column_size(description, model.Tag.description): raise InvalidTagDescriptionError('Description is too long.') - tag.description = description + tag.description = description or None diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 5547bbae..fd0c6240 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -1,8 +1,9 @@ -import datetime import re -from sqlalchemy import func -from szurubooru import config, db, errors -from szurubooru.func import auth, util, files, images +from typing import Any, Optional, Union, List, Dict, Callable +from datetime import datetime +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest +from szurubooru.func import auth, util, serialization, files, images class UserNotFoundError(errors.NotFoundError): @@ -33,11 +34,11 @@ class InvalidAvatarError(errors.ValidationError): pass -def get_avatar_path(user_name): +def get_avatar_path(user_name: str) -> str: return 'avatars/' + user_name.lower() + '.png' -def get_avatar_url(user): +def get_avatar_url(user: model.User) -> str: assert user if user.avatar_style == user.AVATAR_GRAVATAR: assert user.email or user.name @@ -49,7 +50,10 @@ def get_avatar_url(user): config.config['data_url'].rstrip('/'), user.name.lower()) -def get_email(user, auth_user, force_show_email): +def get_email( + user: model.User, + auth_user: model.User, + force_show_email: bool) -> Union[bool, str]: assert user assert auth_user if not force_show_email \ @@ -59,7 +63,8 @@ def get_email(user, auth_user, force_show_email): return user.email -def get_liked_post_count(user, auth_user): +def get_liked_post_count( + user: model.User, auth_user: model.User) -> Union[bool, int]: assert user assert auth_user if auth_user.user_id != user.user_id: @@ -67,7 +72,8 @@ def get_liked_post_count(user, auth_user): return user.liked_post_count -def get_disliked_post_count(user, auth_user): +def get_disliked_post_count( + user: model.User, auth_user: model.User) -> Union[bool, int]: assert user assert auth_user if auth_user.user_id != user.user_id: @@ -75,91 +81,144 @@ def get_disliked_post_count(user, auth_user): return user.disliked_post_count -def serialize_user(user, auth_user, options=None, force_show_email=False): - return util.serialize_entity( - user, - { - 'name': lambda: user.name, - 'creationTime': lambda: user.creation_time, - 'lastLoginTime': lambda: user.last_login_time, - 'version': lambda: user.version, - 'rank': lambda: user.rank, - 'avatarStyle': lambda: user.avatar_style, - 'avatarUrl': lambda: get_avatar_url(user), - 'commentCount': lambda: user.comment_count, - 'uploadedPostCount': lambda: user.post_count, - 'favoritePostCount': lambda: user.favorite_post_count, - 'likedPostCount': - lambda: get_liked_post_count(user, auth_user), - 'dislikedPostCount': - lambda: get_disliked_post_count(user, auth_user), - 'email': - lambda: get_email(user, auth_user, force_show_email), - }, - options) +class UserSerializer(serialization.BaseSerializer): + def __init__( + self, + user: model.User, + auth_user: model.User, + force_show_email: bool=False) -> None: + self.user = user + self.auth_user = auth_user + self.force_show_email = force_show_email + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'name': self.serialize_name, + 'creationTime': self.serialize_creation_time, + 'lastLoginTime': self.serialize_last_login_time, + 'version': self.serialize_version, + 'rank': self.serialize_rank, + 'avatarStyle': self.serialize_avatar_style, + 'avatarUrl': self.serialize_avatar_url, + 'commentCount': self.serialize_comment_count, + 'uploadedPostCount': self.serialize_uploaded_post_count, + 'favoritePostCount': self.serialize_favorite_post_count, + 'likedPostCount': self.serialize_liked_post_count, + 'dislikedPostCount': self.serialize_disliked_post_count, + 'email': self.serialize_email, + } + + def serialize_name(self) -> Any: + return self.user.name + + def serialize_creation_time(self) -> Any: + return self.user.creation_time + + def serialize_last_login_time(self) -> Any: + return self.user.last_login_time + + def serialize_version(self) -> Any: + return self.user.version + + def serialize_rank(self) -> Any: + return self.user.rank + + def serialize_avatar_style(self) -> Any: + return self.user.avatar_style + + def serialize_avatar_url(self) -> Any: + return get_avatar_url(self.user) + + def serialize_comment_count(self) -> Any: + return self.user.comment_count + + def serialize_uploaded_post_count(self) -> Any: + return self.user.post_count + + def serialize_favorite_post_count(self) -> Any: + return self.user.favorite_post_count + + def serialize_liked_post_count(self) -> Any: + return get_liked_post_count(self.user, self.auth_user) + + def serialize_disliked_post_count(self) -> Any: + return get_disliked_post_count(self.user, self.auth_user) + + def serialize_email(self) -> Any: + return get_email(self.user, self.auth_user, self.force_show_email) -def serialize_micro_user(user, auth_user): +def serialize_user( + user: Optional[model.User], + auth_user: model.User, + options: List[str]=[], + force_show_email: bool=False) -> Optional[rest.Response]: + if not user: + return None + return UserSerializer(user, auth_user, force_show_email).serialize(options) + + +def serialize_micro_user( + user: Optional[model.User], + auth_user: model.User) -> Optional[rest.Response]: return serialize_user( - user, - auth_user=auth_user, - options=['name', 'avatarUrl']) + user, auth_user=auth_user, options=['name', 'avatarUrl']) -def get_user_count(): - return db.session.query(db.User).count() +def get_user_count() -> int: + return db.session.query(model.User).count() -def try_get_user_by_name(name): +def try_get_user_by_name(name: str) -> Optional[model.User]: return db.session \ - .query(db.User) \ - .filter(func.lower(db.User.name) == func.lower(name)) \ + .query(model.User) \ + .filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \ .one_or_none() -def get_user_by_name(name): +def get_user_by_name(name: str) -> model.User: user = try_get_user_by_name(name) if not user: raise UserNotFoundError('User %r not found.' % name) return user -def try_get_user_by_name_or_email(name_or_email): +def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]: return ( db.session - .query(db.User) + .query(model.User) .filter( - (func.lower(db.User.name) == func.lower(name_or_email)) | - (func.lower(db.User.email) == func.lower(name_or_email))) + (sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) | + (sa.func.lower(model.User.email) == sa.func.lower(name_or_email))) .one_or_none()) -def get_user_by_name_or_email(name_or_email): +def get_user_by_name_or_email(name_or_email: str) -> model.User: user = try_get_user_by_name_or_email(name_or_email) if not user: raise UserNotFoundError('User %r not found.' % name_or_email) return user -def create_user(name, password, email): - user = db.User() +def create_user(name: str, password: str, email: str) -> model.User: + user = model.User() update_user_name(user, name) update_user_password(user, password) update_user_email(user, email) if get_user_count() > 0: user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']] else: - user.rank = db.User.RANK_ADMINISTRATOR - user.creation_time = datetime.datetime.utcnow() - user.avatar_style = db.User.AVATAR_GRAVATAR + user.rank = model.User.RANK_ADMINISTRATOR + user.creation_time = datetime.utcnow() + user.avatar_style = model.User.AVATAR_GRAVATAR return user -def update_user_name(user, name): +def update_user_name(user: model.User, name: str) -> None: assert user if not name: raise InvalidUserNameError('Name cannot be empty.') - if util.value_exceeds_column_size(name, db.User.name): + if util.value_exceeds_column_size(name, model.User.name): raise InvalidUserNameError('User name is too long.') name = name.strip() name_regex = config.config['user_name_regex'] @@ -174,7 +233,7 @@ def update_user_name(user, name): user.name = name -def update_user_password(user, password): +def update_user_password(user: model.User, password: str) -> None: assert user if not password: raise InvalidPasswordError('Password cannot be empty.') @@ -186,20 +245,18 @@ def update_user_password(user, password): user.password_hash = auth.get_password_hash(user.password_salt, password) -def update_user_email(user, email): +def update_user_email(user: model.User, email: str) -> None: assert user - if email: - email = email.strip() - if not email: - email = None - if email and util.value_exceeds_column_size(email, db.User.email): + email = email.strip() + if util.value_exceeds_column_size(email, model.User.email): raise InvalidEmailError('Email is too long.') if not util.is_valid_email(email): raise InvalidEmailError('E-mail is invalid.') - user.email = email + user.email = email or None -def update_user_rank(user, rank, auth_user): +def update_user_rank( + user: model.User, rank: str, auth_user: model.User) -> None: assert user if not rank: raise InvalidRankError('Rank cannot be empty.') @@ -208,7 +265,7 @@ def update_user_rank(user, rank, auth_user): if not rank: raise InvalidRankError( 'Rank can be either of %r.' % all_ranks) - if rank in (db.User.RANK_ANONYMOUS, db.User.RANK_NOBODY): + if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY): raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank]) if all_ranks.index(auth_user.rank) \ < all_ranks.index(rank) and get_user_count() > 0: @@ -216,7 +273,10 @@ def update_user_rank(user, rank, auth_user): user.rank = rank -def update_user_avatar(user, avatar_style, avatar_content=None): +def update_user_avatar( + user: model.User, + avatar_style: str, + avatar_content: Optional[bytes]=None) -> None: assert user if avatar_style == 'gravatar': user.avatar_style = user.AVATAR_GRAVATAR @@ -238,12 +298,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None): avatar_style, ['gravatar', 'manual'])) -def bump_user_login_time(user): +def bump_user_login_time(user: model.User) -> None: assert user - user.last_login_time = datetime.datetime.utcnow() + user.last_login_time = datetime.utcnow() -def reset_user_password(user): +def reset_user_password(user: model.User) -> str: assert user password = auth.create_password() user.password_salt = auth.create_password() diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 11caedd2..40d19d39 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -2,52 +2,39 @@ import os import hashlib import re import tempfile +from typing import ( + Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar) from datetime import datetime, timedelta from contextlib import contextmanager from szurubooru import errors -def snake_case_to_lower_camel_case(text): +T = TypeVar('T') + + +def snake_case_to_lower_camel_case(text: str) -> str: components = text.split('_') return components[0].lower() + \ ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) -def snake_case_to_upper_train_case(text): +def snake_case_to_upper_train_case(text: str) -> str: return '-'.join( word[0].upper() + word[1:].lower() for word in text.split('_')) -def snake_case_to_lower_camel_case_keys(source): +def snake_case_to_lower_camel_case_keys( + source: Dict[str, Any]) -> Dict[str, Any]: target = {} for key, value in source.items(): target[snake_case_to_lower_camel_case(key)] = value return target -def get_serialization_options(ctx): - return ctx.get_param_as_list('fields', required=False, default=None) - - -def serialize_entity(entity, field_factories, options): - if not entity: - return None - if not options or len(options) == 0: - options = field_factories.keys() - ret = {} - for key in options: - if key not in field_factories: - raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % ( - key, list(sorted(field_factories.keys())))) - factory = field_factories[key] - ret[key] = factory() - return ret - - @contextmanager -def create_temp_file(**kwargs): - (handle, path) = tempfile.mkstemp(**kwargs) - os.close(handle) +def create_temp_file(**kwargs: Any) -> Generator: + (descriptor, path) = tempfile.mkstemp(**kwargs) + os.close(descriptor) try: with open(path, 'r+b') as handle: yield handle @@ -55,17 +42,15 @@ def create_temp_file(**kwargs): os.remove(path) -def unalias_dict(input_dict): - output_dict = {} - for key_list, value in input_dict.items(): - if isinstance(key_list, str): - key_list = [key_list] - for key in key_list: - output_dict[key] = value +def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]: + output_dict = {} # type: Dict[str, T] + for aliases, value in source: + for alias in aliases: + output_dict[alias] = value return output_dict -def get_md5(source): +def get_md5(source: Union[str, bytes]) -> str: if not isinstance(source, bytes): source = source.encode('utf-8') md5 = hashlib.md5() @@ -73,7 +58,7 @@ def get_md5(source): return md5.hexdigest() -def get_sha1(source): +def get_sha1(source: Union[str, bytes]) -> str: if not isinstance(source, bytes): source = source.encode('utf-8') sha1 = hashlib.sha1() @@ -81,24 +66,25 @@ def get_sha1(source): return sha1.hexdigest() -def flip(source): +def flip(source: Dict[Any, Any]) -> Dict[Any, Any]: return {v: k for k, v in source.items()} -def is_valid_email(email): +def is_valid_email(email: Optional[str]) -> bool: ''' Return whether given email address is valid or empty. ''' - return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) + return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None class dotdict(dict): # pylint: disable=invalid-name ''' dot.notation access to dictionary attributes. ''' - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return self.get(attr) + __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ -def parse_time_range(value): +def parse_time_range(value: str) -> Tuple[datetime, datetime]: ''' Return tuple containing min/max time for given text representation. ''' one_day = timedelta(days=1) one_second = timedelta(seconds=1) @@ -146,9 +132,9 @@ def parse_time_range(value): raise errors.ValidationError('Invalid date format: %r.' % value) -def icase_unique(source): - target = [] - target_low = [] +def icase_unique(source: List[str]) -> List[str]: + target = [] # type: List[str] + target_low = [] # type: List[str] for source_item in source: if source_item.lower() not in target_low: target.append(source_item) @@ -156,7 +142,7 @@ def icase_unique(source): return target -def value_exceeds_column_size(value, column): +def value_exceeds_column_size(value: Optional[str], column: Any) -> bool: if not value: return False max_length = column.property.columns[0].type.length @@ -165,6 +151,6 @@ def value_exceeds_column_size(value, column): return len(value) > max_length -def chunks(source_list, part_size): +def chunks(source_list: List[Any], part_size: int) -> Generator: for i in range(0, len(source_list), part_size): yield source_list[i:i + part_size] diff --git a/server/szurubooru/func/versions.py b/server/szurubooru/func/versions.py index ee84407b..459b0256 100644 --- a/server/szurubooru/func/versions.py +++ b/server/szurubooru/func/versions.py @@ -1,8 +1,11 @@ -from szurubooru import errors +from szurubooru import errors, rest, model -def verify_version(entity, context, field_name='version'): - actual_version = context.get_param_as_int(field_name, required=True) +def verify_version( + entity: model.Base, + context: rest.Context, + field_name: str='version') -> None: + actual_version = context.get_param_as_int(field_name) expected_version = entity.version if actual_version != expected_version: raise errors.IntegrityError( @@ -10,5 +13,5 @@ def verify_version(entity, context, field_name='version'): 'Please try again.') -def bump_version(entity): +def bump_version(entity: model.Base) -> None: entity.version = entity.version + 1 diff --git a/server/szurubooru/middleware/authenticator.py b/server/szurubooru/middleware/authenticator.py index f6b7853f..2c5ac087 100644 --- a/server/szurubooru/middleware/authenticator.py +++ b/server/szurubooru/middleware/authenticator.py @@ -1,11 +1,11 @@ import base64 -from szurubooru import db, errors +from typing import Optional +from szurubooru import db, model, errors, rest from szurubooru.func import auth, users -from szurubooru.rest import middleware from szurubooru.rest.errors import HttpBadRequest -def _authenticate(username, password): +def _authenticate(username: str, password: str) -> model.User: ''' Try to authenticate user. Throw AuthError for invalid users. ''' user = users.get_user_by_name(username) if not auth.is_valid_password(user, password): @@ -13,16 +13,9 @@ def _authenticate(username, password): return user -def _create_anonymous_user(): - user = db.User() - user.name = None - user.rank = 'anonymous' - return user - - -def _get_user(ctx): +def _get_user(ctx: rest.Context) -> Optional[model.User]: if not ctx.has_header('Authorization'): - return _create_anonymous_user() + return None try: auth_type, credentials = ctx.get_header('Authorization').split(' ', 1) @@ -41,10 +34,12 @@ def _get_user(ctx): msg.format(ctx.get_header('Authorization'), str(err))) -@middleware.pre_hook -def process_request(ctx): +@rest.middleware.pre_hook +def process_request(ctx: rest.Context) -> None: ''' Bind the user to request. Update last login time if needed. ''' - ctx.user = _get_user(ctx) - if ctx.get_param_as_bool('bump-login') and ctx.user.user_id: + auth_user = _get_user(ctx) + if auth_user: + ctx.user = auth_user + if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id: users.bump_user_login_time(ctx.user) ctx.session.commit() diff --git a/server/szurubooru/middleware/cache_purger.py b/server/szurubooru/middleware/cache_purger.py index e26b3bae..d83fb845 100644 --- a/server/szurubooru/middleware/cache_purger.py +++ b/server/szurubooru/middleware/cache_purger.py @@ -1,8 +1,9 @@ +from szurubooru import rest from szurubooru.func import cache from szurubooru.rest import middleware @middleware.pre_hook -def process_request(ctx): +def process_request(ctx: rest.Context) -> None: if ctx.method != 'GET': cache.purge() diff --git a/server/szurubooru/middleware/request_logger.py b/server/szurubooru/middleware/request_logger.py index 47b43ab5..54e40e4a 100644 --- a/server/szurubooru/middleware/request_logger.py +++ b/server/szurubooru/middleware/request_logger.py @@ -1,5 +1,5 @@ import logging -from szurubooru import db +from szurubooru import db, rest from szurubooru.rest import middleware @@ -7,12 +7,12 @@ logger = logging.getLogger(__name__) @middleware.pre_hook -def process_request(_ctx): +def process_request(_ctx: rest.Context) -> None: db.reset_query_count() @middleware.post_hook -def process_response(ctx): +def process_response(ctx: rest.Context) -> None: logger.info( '%s %s (user=%s, queries=%d)', ctx.method, diff --git a/server/szurubooru/migrations/env.py b/server/szurubooru/migrations/env.py index 1359ab8a..a4257d48 100644 --- a/server/szurubooru/migrations/env.py +++ b/server/szurubooru/migrations/env.py @@ -2,7 +2,7 @@ import os import sys import alembic -import sqlalchemy +import sqlalchemy as sa import logging.config # make szurubooru module importable @@ -48,7 +48,7 @@ def run_migrations_online(): In this scenario we need to create an Engine and associate a connection with the context. ''' - connectable = sqlalchemy.engine_from_config( + connectable = sa.engine_from_config( alembic_config.get_section(alembic_config.config_ini_section), prefix='sqlalchemy.', poolclass=sqlalchemy.pool.NullPool) diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py new file mode 100644 index 00000000..ad2231c2 --- /dev/null +++ b/server/szurubooru/model/__init__.py @@ -0,0 +1,15 @@ +from szurubooru.model.base import Base +from szurubooru.model.user import User +from szurubooru.model.tag_category import TagCategory +from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication +from szurubooru.model.post import ( + Post, + PostTag, + PostRelation, + PostFavorite, + PostScore, + PostNote, + PostFeature) +from szurubooru.model.comment import Comment, CommentScore +from szurubooru.model.snapshot import Snapshot +import szurubooru.model.util diff --git a/server/szurubooru/db/base.py b/server/szurubooru/model/base.py similarity index 100% rename from server/szurubooru/db/base.py rename to server/szurubooru/model/base.py diff --git a/server/szurubooru/db/comment.py b/server/szurubooru/model/comment.py similarity index 84% rename from server/szurubooru/db/comment.py rename to server/szurubooru/model/comment.py index bf325859..55c1596b 100644 --- a/server/szurubooru/db/comment.py +++ b/server/szurubooru/model/comment.py @@ -1,7 +1,8 @@ from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey from sqlalchemy.orm import relationship, backref from sqlalchemy.sql.expression import func -from szurubooru.db.base import Base +from szurubooru.db import get_session +from szurubooru.model.base import Base class CommentScore(Base): @@ -48,12 +49,12 @@ class Comment(Base): 'CommentScore', cascade='all, delete-orphan', lazy='joined') @property - def score(self): - from szurubooru.db import session - return session \ - .query(func.sum(CommentScore.score)) \ - .filter(CommentScore.comment_id == self.comment_id) \ - .one()[0] or 0 + def score(self) -> int: + return ( + get_session() + .query(func.sum(CommentScore.score)) + .filter(CommentScore.comment_id == self.comment_id) + .one()[0] or 0) __mapper_args__ = { 'version_id_col': version, diff --git a/server/szurubooru/db/post.py b/server/szurubooru/model/post.py similarity index 95% rename from server/szurubooru/db/post.py rename to server/szurubooru/model/post.py index f0c9f91f..23f52b57 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/model/post.py @@ -3,8 +3,8 @@ from sqlalchemy import ( Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) from sqlalchemy.orm import ( relationship, column_property, object_session, backref) -from szurubooru.db.base import Base -from szurubooru.db.comment import Comment +from szurubooru.model.base import Base +from szurubooru.model.comment import Comment class PostFeature(Base): @@ -17,10 +17,9 @@ class PostFeature(Base): 'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True) time = Column('time', DateTime, nullable=False) - post = relationship('Post') + post = relationship('Post') # type: Post user = relationship( - 'User', - backref=backref('post_features', cascade='all, delete-orphan')) + 'User', backref=backref('post_features', cascade='all, delete-orphan')) class PostScore(Base): @@ -104,7 +103,7 @@ class PostRelation(Base): nullable=False, index=True) - def __init__(self, parent_id, child_id): + def __init__(self, parent_id: int, child_id: int) -> None: self.parent_id = parent_id self.child_id = child_id @@ -127,7 +126,7 @@ class PostTag(Base): nullable=False, index=True) - def __init__(self, post_id, tag_id): + def __init__(self, post_id: int, tag_id: int) -> None: self.post_id = post_id self.tag_id = tag_id @@ -197,7 +196,7 @@ class Post(Base): canvas_area = column_property(canvas_width * canvas_height) @property - def is_featured(self): + def is_featured(self) -> bool: featured_post = object_session(self) \ .query(PostFeature) \ .order_by(PostFeature.time.desc()) \ diff --git a/server/szurubooru/db/snapshot.py b/server/szurubooru/model/snapshot.py similarity index 96% rename from server/szurubooru/db/snapshot.py rename to server/szurubooru/model/snapshot.py index 4b211f61..beb3bb25 100644 --- a/server/szurubooru/db/snapshot.py +++ b/server/szurubooru/model/snapshot.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import relationship from sqlalchemy import ( Column, Integer, DateTime, Unicode, PickleType, ForeignKey) -from szurubooru.db.base import Base +from szurubooru.model.base import Base class Snapshot(Base): diff --git a/server/szurubooru/db/tag.py b/server/szurubooru/model/tag.py similarity index 93% rename from server/szurubooru/db/tag.py rename to server/szurubooru/model/tag.py index 10813eb9..1bce3ffa 100644 --- a/server/szurubooru/db/tag.py +++ b/server/szurubooru/model/tag.py @@ -2,8 +2,8 @@ from sqlalchemy import ( Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey) from sqlalchemy.orm import relationship, column_property from sqlalchemy.sql.expression import func, select -from szurubooru.db.base import Base -from szurubooru.db.post import PostTag +from szurubooru.model.base import Base +from szurubooru.model.post import PostTag class TagSuggestion(Base): @@ -24,7 +24,7 @@ class TagSuggestion(Base): primary_key=True, index=True) - def __init__(self, parent_id, child_id): + def __init__(self, parent_id: int, child_id: int) -> None: self.parent_id = parent_id self.child_id = child_id @@ -47,7 +47,7 @@ class TagImplication(Base): primary_key=True, index=True) - def __init__(self, parent_id, child_id): + def __init__(self, parent_id: int, child_id: int) -> None: self.parent_id = parent_id self.child_id = child_id @@ -61,7 +61,7 @@ class TagName(Base): name = Column('name', Unicode(64), nullable=False, unique=True) order = Column('ord', Integer, nullable=False, index=True) - def __init__(self, name, order): + def __init__(self, name: str, order: int) -> None: self.name = name self.order = order diff --git a/server/szurubooru/db/tag_category.py b/server/szurubooru/model/tag_category.py similarity index 84% rename from server/szurubooru/db/tag_category.py rename to server/szurubooru/model/tag_category.py index 907910ba..001f9653 100644 --- a/server/szurubooru/db/tag_category.py +++ b/server/szurubooru/model/tag_category.py @@ -1,8 +1,9 @@ +from typing import Optional from sqlalchemy import Column, Integer, Unicode, Boolean, table from sqlalchemy.orm import column_property from sqlalchemy.sql.expression import func, select -from szurubooru.db.base import Base -from szurubooru.db.tag import Tag +from szurubooru.model.base import Base +from szurubooru.model.tag import Tag class TagCategory(Base): @@ -14,7 +15,7 @@ class TagCategory(Base): color = Column('color', Unicode(32), nullable=False, default='#000000') default = Column('default', Boolean, nullable=False, default=False) - def __init__(self, name=None): + def __init__(self, name: Optional[str]=None) -> None: self.name = name tag_count = column_property( diff --git a/server/szurubooru/db/user.py b/server/szurubooru/model/user.py similarity index 50% rename from server/szurubooru/db/user.py rename to server/szurubooru/model/user.py index 4f4f9961..dd7c0629 100644 --- a/server/szurubooru/db/user.py +++ b/server/szurubooru/model/user.py @@ -1,9 +1,7 @@ -from sqlalchemy import Column, Integer, Unicode, DateTime -from sqlalchemy.orm import relationship -from sqlalchemy.sql.expression import func -from szurubooru.db.base import Base -from szurubooru.db.post import Post, PostScore, PostFavorite -from szurubooru.db.comment import Comment +import sqlalchemy as sa +from szurubooru.model.base import Base +from szurubooru.model.post import Post, PostScore, PostFavorite +from szurubooru.model.comment import Comment class User(Base): @@ -20,63 +18,64 @@ class User(Base): RANK_ADMINISTRATOR = 'administrator' RANK_NOBODY = 'nobody' # unattainable, used for privileges - user_id = Column('id', Integer, primary_key=True) - creation_time = Column('creation_time', DateTime, nullable=False) - last_login_time = Column('last_login_time', DateTime) - version = Column('version', Integer, default=1, nullable=False) - name = Column('name', Unicode(50), nullable=False, unique=True) - password_hash = Column('password_hash', Unicode(64), nullable=False) - password_salt = Column('password_salt', Unicode(32)) - email = Column('email', Unicode(64), nullable=True) - rank = Column('rank', Unicode(32), nullable=False) - avatar_style = Column( - 'avatar_style', Unicode(32), nullable=False, default=AVATAR_GRAVATAR) + user_id = sa.Column('id', sa.Integer, primary_key=True) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_login_time = sa.Column('last_login_time', sa.DateTime) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True) + password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False) + password_salt = sa.Column('password_salt', sa.Unicode(32)) + email = sa.Column('email', sa.Unicode(64), nullable=True) + rank = sa.Column('rank', sa.Unicode(32), nullable=False) + avatar_style = sa.Column( + 'avatar_style', sa.Unicode(32), nullable=False, + default=AVATAR_GRAVATAR) - comments = relationship('Comment') + comments = sa.orm.relationship('Comment') @property - def post_count(self): + def post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(Post.user_id == self.user_id) .one()[0] or 0) @property - def comment_count(self): + def comment_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(Comment.user_id == self.user_id) .one()[0] or 0) @property - def favorite_post_count(self): + def favorite_post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(PostFavorite.user_id == self.user_id) .one()[0] or 0) @property - def liked_post_count(self): + def liked_post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == 1) .one()[0] or 0) @property - def disliked_post_count(self): + def disliked_post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == -1) .one()[0] or 0) diff --git a/server/szurubooru/model/util.py b/server/szurubooru/model/util.py new file mode 100644 index 00000000..e82539f1 --- /dev/null +++ b/server/szurubooru/model/util.py @@ -0,0 +1,42 @@ +from typing import Tuple, Any, Dict, Callable, Union, Optional +import sqlalchemy as sa +from szurubooru.model.base import Base +from szurubooru.model.user import User + + +def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]: + serializers = { + 'tag': lambda tag: tag.first_name, + 'tag_category': lambda category: category.name, + 'comment': lambda comment: comment.comment_id, + 'post': lambda post: post.post_id, + } # type: Dict[str, Callable[[Base], Any]] + + resource_type = entity.__table__.name + assert resource_type in serializers + + primary_key = sa.inspection.inspect(entity).identity # type: Any + assert primary_key is not None + assert len(primary_key) == 1 + + resource_name = serializers[resource_type](entity) # type: Union[str, int] + assert resource_name + + resource_pkey = primary_key[0] # type: Any + assert resource_pkey + + return (resource_type, resource_pkey, resource_name) + + +def get_aux_entity( + session: Any, + get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]], + entity: Base, + user: User) -> Optional[Base]: + table, get_column = get_table_info(entity) + return ( + session + .query(table) + .filter(get_column(table) == get_column(entity)) + .filter(table.user_id == user.user_id) + .one_or_none()) diff --git a/server/szurubooru/rest/__init__.py b/server/szurubooru/rest/__init__.py index ac9958a5..14a3e305 100644 --- a/server/szurubooru/rest/__init__.py +++ b/server/szurubooru/rest/__init__.py @@ -1,2 +1,2 @@ from szurubooru.rest.app import application -from szurubooru.rest.context import Context +from szurubooru.rest.context import Context, Response diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index 1bbf8dce..b29110e7 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -2,13 +2,14 @@ import urllib.parse import cgi import json import re +from typing import Dict, Any, Callable, Tuple from datetime import datetime from szurubooru import db from szurubooru.func import util from szurubooru.rest import errors, middleware, routes, context -def _json_serializer(obj): +def _json_serializer(obj: Any) -> str: ''' JSON serializer for objects not serializable by default JSON code ''' if isinstance(obj, datetime): serial = obj.isoformat('T') + 'Z' @@ -16,12 +17,12 @@ def _json_serializer(obj): raise TypeError('Type not serializable') -def _dump_json(obj): +def _dump_json(obj: Any) -> str: return json.dumps(obj, default=_json_serializer, indent=2) -def _get_headers(env): - headers = {} +def _get_headers(env: Dict[str, Any]) -> Dict[str, str]: + headers = {} # type: Dict[str, str] for key, value in env.items(): if key.startswith('HTTP_'): key = util.snake_case_to_upper_train_case(key[5:]) @@ -29,7 +30,7 @@ def _get_headers(env): return headers -def _create_context(env): +def _create_context(env: Dict[str, Any]) -> context.Context: method = env['REQUEST_METHOD'] path = '/' + env['PATH_INFO'].lstrip('/') headers = _get_headers(env) @@ -64,7 +65,9 @@ def _create_context(env): return context.Context(method, path, headers, params, files) -def application(env, start_response): +def application( + env: Dict[str, Any], + start_response: Callable[[str, Any], Any]) -> Tuple[bytes]: try: ctx = _create_context(env) if 'application/json' not in ctx.get_header('Accept'): @@ -106,9 +109,9 @@ def application(env, start_response): return (_dump_json(response).encode('utf-8'),) except Exception as ex: - for exception_type, handler in errors.error_handlers.items(): + for exception_type, ex_handler in errors.error_handlers.items(): if isinstance(ex, exception_type): - handler(ex) + ex_handler(ex) raise except errors.BaseHttpError as ex: diff --git a/server/szurubooru/rest/context.py b/server/szurubooru/rest/context.py index ae26f38b..bb33bfab 100644 --- a/server/szurubooru/rest/context.py +++ b/server/szurubooru/rest/context.py @@ -1,111 +1,158 @@ -from szurubooru import errors +from typing import Any, Union, List, Dict, Optional, cast +from szurubooru import model, errors from szurubooru.func import net, file_uploads -def _lower_first(source): - return source[0].lower() + source[1:] - - -def _param_wrapper(func): - def wrapper(self, name, required=False, default=None, **kwargs): - # pylint: disable=protected-access - if name in self._params: - value = self._params[name] - try: - value = func(self, value, **kwargs) - except errors.InvalidParameterError as ex: - raise errors.InvalidParameterError( - 'Parameter %r is invalid: %s' % ( - name, _lower_first(str(ex)))) - return value - if not required: - return default - raise errors.MissingRequiredParameterError( - 'Required parameter %r is missing.' % name) - return wrapper +MISSING = object() +Request = Dict[str, Any] +Response = Optional[Dict[str, Any]] class Context: - def __init__(self, method, url, headers=None, params=None, files=None): + def __init__( + self, + method: str, + url: str, + headers: Dict[str, str]=None, + params: Request=None, + files: Dict[str, bytes]=None) -> None: self.method = method self.url = url self._headers = headers or {} self._params = params or {} self._files = files or {} - # provided by middleware - # self.session = None - # self.user = None + self.user = model.User() + self.user.name = None + self.user.rank = 'anonymous' - def has_header(self, name): + self.session = None # type: Any + + def has_header(self, name: str) -> bool: return name in self._headers - def get_header(self, name): - return self._headers.get(name, None) + def get_header(self, name: str) -> str: + return self._headers.get(name, '') - def has_file(self, name, allow_tokens=True): + def has_file(self, name: str, allow_tokens: bool=True) -> bool: return ( name in self._files or name + 'Url' in self._params or (allow_tokens and name + 'Token' in self._params)) - def get_file(self, name, required=False, allow_tokens=True): - ret = None - if name in self._files: - ret = self._files[name] - elif name + 'Url' in self._params: - ret = net.download(self._params[name + 'Url']) - elif allow_tokens and name + 'Token' in self._params: + def get_file( + self, + name: str, + default: Union[object, bytes]=MISSING, + allow_tokens: bool=True) -> bytes: + if name in self._files and self._files[name]: + return self._files[name] + + if name + 'Url' in self._params: + return net.download(self._params[name + 'Url']) + + if allow_tokens and name + 'Token' in self._params: ret = file_uploads.get(self._params[name + 'Token']) - if required and not ret: + if ret: + return ret + elif default is not MISSING: raise errors.MissingOrExpiredRequiredFileError( 'Required file %r is missing or has expired.' % name) - if required and not ret: - raise errors.MissingRequiredFileError( - 'Required file %r is missing.' % name) - return ret - def has_param(self, name): + if default is not MISSING: + return cast(bytes, default) + raise errors.MissingRequiredFileError( + 'Required file %r is missing.' % name) + + def has_param(self, name: str) -> bool: return name in self._params - @_param_wrapper - def get_param_as_list(self, value): - if not isinstance(value, list): + def get_param_as_list( + self, + name: str, + default: Union[object, List[Any]]=MISSING) -> List[Any]: + if name not in self._params: + if default is not MISSING: + return cast(List[Any], default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] + if type(value) is str: if ',' in value: return value.split(',') return [value] - return value + if type(value) is list: + return value + raise errors.InvalidParameterError( + 'Parameter %r must be a list.' % name) - @_param_wrapper - def get_param_as_string(self, value): - if isinstance(value, list): - try: - value = ','.join(value) - except TypeError: - raise errors.InvalidParameterError('Expected simple string.') - return value + def get_param_as_string( + self, + name: str, + default: Union[object, str]=MISSING) -> str: + if name not in self._params: + if default is not MISSING: + return cast(str, default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] + try: + if value is None: + return '' + if type(value) is list: + return ','.join(value) + if type(value) is int or type(value) is float: + return str(value) + if type(value) is str: + return value + except TypeError: + pass + raise errors.InvalidParameterError( + 'Parameter %r must be a string value.' % name) - @_param_wrapper - def get_param_as_int(self, value, min=None, max=None): + def get_param_as_int( + self, + name: str, + default: Union[object, int]=MISSING, + min: Optional[int]=None, + max: Optional[int]=None) -> int: + if name not in self._params: + if default is not MISSING: + return cast(int, default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] try: value = int(value) + if min is not None and value < min: + raise errors.InvalidParameterError( + 'Parameter %r must be at least %r.' % (name, min)) + if max is not None and value > max: + raise errors.InvalidParameterError( + 'Parameter %r may not exceed %r.' % (name, max)) + return value except (ValueError, TypeError): - raise errors.InvalidParameterError( - 'The value must be an integer.') - if min is not None and value < min: - raise errors.InvalidParameterError( - 'The value must be at least %r.' % min) - if max is not None and value > max: - raise errors.InvalidParameterError( - 'The value may not exceed %r.' % max) - return value + pass + raise errors.InvalidParameterError( + 'Parameter %r must be an integer value.' % name) - @_param_wrapper - def get_param_as_bool(self, value): - value = str(value).lower() + def get_param_as_bool( + self, + name: str, + default: Union[object, bool]=MISSING) -> bool: + if name not in self._params: + if default is not MISSING: + return cast(bool, default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] + try: + value = str(value).lower() + except TypeError: + pass if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']: return True if value in ['0', 'n', 'no', 'nope', 'f', 'false']: return False raise errors.InvalidParameterError( - 'The value must be a boolean value.') + 'Parameter %r must be a boolean value.' % name) diff --git a/server/szurubooru/rest/errors.py b/server/szurubooru/rest/errors.py index b0f5b882..6854e7d3 100644 --- a/server/szurubooru/rest/errors.py +++ b/server/szurubooru/rest/errors.py @@ -1,11 +1,19 @@ +from typing import Callable, Type, Dict + + error_handlers = {} # pylint: disable=invalid-name class BaseHttpError(RuntimeError): - code = None - reason = None + code = -1 + reason = '' - def __init__(self, name, description, title=None, extra_fields=None): + def __init__( + self, + name: str, + description: str, + title: str=None, + extra_fields: Dict[str, str]=None) -> None: super().__init__() # error name for programmers self.name = name @@ -52,5 +60,7 @@ class HttpInternalServerError(BaseHttpError): reason = 'Internal Server Error' -def handle(exception_type, handler): +def handle( + exception_type: Type[Exception], + handler: Callable[[Exception], None]) -> None: error_handlers[exception_type] = handler diff --git a/server/szurubooru/rest/middleware.py b/server/szurubooru/rest/middleware.py index 7cf07296..05d9495e 100644 --- a/server/szurubooru/rest/middleware.py +++ b/server/szurubooru/rest/middleware.py @@ -1,11 +1,15 @@ +from typing import Callable +from szurubooru.rest.context import Context + + # pylint: disable=invalid-name -pre_hooks = [] -post_hooks = [] +pre_hooks = [] # type: List[Callable[[Context], None]] +post_hooks = [] # type: List[Callable[[Context], None]] -def pre_hook(handler): +def pre_hook(handler: Callable) -> None: pre_hooks.append(handler) -def post_hook(handler): +def post_hook(handler: Callable) -> None: post_hooks.insert(0, handler) diff --git a/server/szurubooru/rest/routes.py b/server/szurubooru/rest/routes.py index ffa95f56..c0b6bea3 100644 --- a/server/szurubooru/rest/routes.py +++ b/server/szurubooru/rest/routes.py @@ -1,32 +1,36 @@ +from typing import Callable, Dict, Any from collections import defaultdict +from szurubooru.rest.context import Context, Response -routes = defaultdict(dict) # pylint: disable=invalid-name +# pylint: disable=invalid-name +RouteHandler = Callable[[Context, Dict[str, str]], Response] +routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]] -def get(url): - def wrapper(handler): +def get(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['GET'] = handler return handler return wrapper -def put(url): - def wrapper(handler): +def put(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['PUT'] = handler return handler return wrapper -def post(url): - def wrapper(handler): +def post(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['POST'] = handler return handler return wrapper -def delete(url): - def wrapper(handler): +def delete(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['DELETE'] = handler return handler return wrapper diff --git a/server/szurubooru/search/configs/base_search_config.py b/server/szurubooru/search/configs/base_search_config.py index adc50d30..0cb814d4 100644 --- a/server/szurubooru/search/configs/base_search_config.py +++ b/server/szurubooru/search/configs/base_search_config.py @@ -1,38 +1,47 @@ -from szurubooru.search import tokens +from typing import Optional, Tuple, Dict, Callable +from szurubooru.search import tokens, criteria +from szurubooru.search.query import SearchQuery +from szurubooru.search.typing import SaColumn, SaQuery + +Filter = Callable[[SaQuery, Optional[criteria.BaseCriterion], bool], SaQuery] class BaseSearchConfig: + SORT_NONE = tokens.SortToken.SORT_NONE SORT_ASC = tokens.SortToken.SORT_ASC SORT_DESC = tokens.SortToken.SORT_DESC - def on_search_query_parsed(self, search_query): + def on_search_query_parsed(self, search_query: SearchQuery) -> None: pass - def create_filter_query(self, _disable_eager_loads): + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: raise NotImplementedError() - def create_count_query(self, disable_eager_loads): + def create_count_query(self, disable_eager_loads: bool) -> SaQuery: raise NotImplementedError() - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() + def finalize_query(self, query: SaQuery) -> SaQuery: + return query + @property - def id_column(self): + def id_column(self) -> SaColumn: return None @property - def anonymous_filter(self): + def anonymous_filter(self) -> Optional[Filter]: return None @property - def special_filters(self): + def special_filters(self) -> Dict[str, Filter]: return {} @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return {} @property - def sort_columns(self): + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: return {} diff --git a/server/szurubooru/search/configs/comment_search_config.py b/server/szurubooru/search/configs/comment_search_config.py index 9b2515e8..8b154460 100644 --- a/server/szurubooru/search/configs/comment_search_config.py +++ b/server/szurubooru/search/configs/comment_search_config.py @@ -1,59 +1,62 @@ -from sqlalchemy.sql.expression import func -from szurubooru import db +from typing import Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model +from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class CommentSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - return db.session.query(db.Comment).join(db.User) + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Comment).join(model.User) - def create_count_query(self, disable_eager_loads): + def create_count_query(self, disable_eager_loads: bool) -> SaQuery: return self.create_filter_query(disable_eager_loads) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.Comment.creation_time.desc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Comment.creation_time.desc()) @property - def anonymous_filter(self): - return search_util.create_str_filter(db.Comment.text) + def anonymous_filter(self) -> SaQuery: + return search_util.create_str_filter(model.Comment.text) @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return { - 'id': search_util.create_num_filter(db.Comment.comment_id), - 'post': search_util.create_num_filter(db.Comment.post_id), - 'user': search_util.create_str_filter(db.User.name), - 'author': search_util.create_str_filter(db.User.name), - 'text': search_util.create_str_filter(db.Comment.text), + 'id': search_util.create_num_filter(model.Comment.comment_id), + 'post': search_util.create_num_filter(model.Comment.post_id), + 'user': search_util.create_str_filter(model.User.name), + 'author': search_util.create_str_filter(model.User.name), + 'text': search_util.create_str_filter(model.Comment.text), 'creation-date': - search_util.create_date_filter(db.Comment.creation_time), + search_util.create_date_filter(model.Comment.creation_time), 'creation-time': - search_util.create_date_filter(db.Comment.creation_time), + search_util.create_date_filter(model.Comment.creation_time), 'last-edit-date': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), 'last-edit-time': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), 'edit-date': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), 'edit-time': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), } @property - def sort_columns(self): + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: return { - 'random': (func.random(), None), - 'user': (db.User.name, self.SORT_ASC), - 'author': (db.User.name, self.SORT_ASC), - 'post': (db.Comment.post_id, self.SORT_DESC), - 'creation-date': (db.Comment.creation_time, self.SORT_DESC), - 'creation-time': (db.Comment.creation_time, self.SORT_DESC), - 'last-edit-date': (db.Comment.last_edit_time, self.SORT_DESC), - 'last-edit-time': (db.Comment.last_edit_time, self.SORT_DESC), - 'edit-date': (db.Comment.last_edit_time, self.SORT_DESC), - 'edit-time': (db.Comment.last_edit_time, self.SORT_DESC), + 'random': (sa.sql.expression.func.random(), self.SORT_NONE), + 'user': (model.User.name, self.SORT_ASC), + 'author': (model.User.name, self.SORT_ASC), + 'post': (model.Comment.post_id, self.SORT_DESC), + 'creation-date': (model.Comment.creation_time, self.SORT_DESC), + 'creation-time': (model.Comment.creation_time, self.SORT_DESC), + 'last-edit-date': (model.Comment.last_edit_time, self.SORT_DESC), + 'last-edit-time': (model.Comment.last_edit_time, self.SORT_DESC), + 'edit-date': (model.Comment.last_edit_time, self.SORT_DESC), + 'edit-time': (model.Comment.last_edit_time, self.SORT_DESC), } diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 7005cd7c..cda1b1ac 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -1,13 +1,16 @@ -from sqlalchemy.orm import subqueryload, lazyload, defer, aliased -from sqlalchemy.sql.expression import func -from szurubooru import db, errors +from typing import Any, Optional, Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model, errors from szurubooru.func import util from szurubooru.search import criteria, tokens +from szurubooru.search.typing import SaColumn, SaQuery +from szurubooru.search.query import SearchQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) -def _enum_transformer(available_values, value): +def _enum_transformer(available_values: Dict[str, Any], value: str) -> str: try: return available_values[value.lower()] except KeyError: @@ -16,71 +19,82 @@ def _enum_transformer(available_values, value): value, list(sorted(available_values.keys())))) -def _type_transformer(value): +def _type_transformer(value: str) -> str: available_values = { - 'image': db.Post.TYPE_IMAGE, - 'animation': db.Post.TYPE_ANIMATION, - 'animated': db.Post.TYPE_ANIMATION, - 'anim': db.Post.TYPE_ANIMATION, - 'gif': db.Post.TYPE_ANIMATION, - 'video': db.Post.TYPE_VIDEO, - 'webm': db.Post.TYPE_VIDEO, - 'flash': db.Post.TYPE_FLASH, - 'swf': db.Post.TYPE_FLASH, + 'image': model.Post.TYPE_IMAGE, + 'animation': model.Post.TYPE_ANIMATION, + 'animated': model.Post.TYPE_ANIMATION, + 'anim': model.Post.TYPE_ANIMATION, + 'gif': model.Post.TYPE_ANIMATION, + 'video': model.Post.TYPE_VIDEO, + 'webm': model.Post.TYPE_VIDEO, + 'flash': model.Post.TYPE_FLASH, + 'swf': model.Post.TYPE_FLASH, } return _enum_transformer(available_values, value) -def _safety_transformer(value): +def _safety_transformer(value: str) -> str: available_values = { - 'safe': db.Post.SAFETY_SAFE, - 'sketchy': db.Post.SAFETY_SKETCHY, - 'questionable': db.Post.SAFETY_SKETCHY, - 'unsafe': db.Post.SAFETY_UNSAFE, + 'safe': model.Post.SAFETY_SAFE, + 'sketchy': model.Post.SAFETY_SKETCHY, + 'questionable': model.Post.SAFETY_SKETCHY, + 'unsafe': model.Post.SAFETY_UNSAFE, } return _enum_transformer(available_values, value) -def _create_score_filter(score): - def wrapper(query, criterion, negated): +def _create_score_filter(score: int) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion if not getattr(criterion, 'internal', False): raise errors.SearchError( 'Votes cannot be seen publicly. Did you mean %r?' % 'special:liked') - user_alias = aliased(db.User) - score_alias = aliased(db.PostScore) + user_alias = sa.orm.aliased(model.User) + score_alias = sa.orm.aliased(model.PostScore) expr = score_alias.score == score expr = expr & search_util.apply_str_criterion_to_column( user_alias.name, criterion) if negated: expr = ~expr ret = query \ - .join(score_alias, score_alias.post_id == db.Post.post_id) \ + .join(score_alias, score_alias.post_id == model.Post.post_id) \ .join(user_alias, user_alias.user_id == score_alias.user_id) \ .filter(expr) return ret return wrapper -def _create_user_filter(): - def wrapper(query, criterion, negated): +def _create_user_filter() -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion if isinstance(criterion, criteria.PlainCriterion) \ and not criterion.value: # pylint: disable=singleton-comparison - expr = db.Post.user_id == None + expr = model.Post.user_id == None if negated: expr = ~expr return query.filter(expr) return search_util.create_subquery_filter( - db.Post.user_id, - db.User.user_id, - db.User.name, + model.Post.user_id, + model.User.user_id, + model.User.name, search_util.create_str_filter)(query, criterion, negated) return wrapper class PostSearchConfig(BaseSearchConfig): - def on_search_query_parsed(self, search_query): + def __init__(self) -> None: + self.user = None # type: Optional[model.User] + + def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery: new_special_tokens = [] for token in search_query.special_tokens: if token.value in ('fav', 'liked', 'disliked'): @@ -91,7 +105,7 @@ class PostSearchConfig(BaseSearchConfig): criterion = criteria.PlainCriterion( original_text=self.user.name, value=self.user.name) - criterion.internal = True + setattr(criterion, 'internal', True) search_query.named_tokens.append( tokens.NamedToken( name=token.value, @@ -101,160 +115,324 @@ class PostSearchConfig(BaseSearchConfig): new_special_tokens.append(token) search_query.special_tokens = new_special_tokens - def create_around_query(self): - return db.session.query(db.Post).options(lazyload('*')) + def create_around_query(self) -> SaQuery: + return db.session.query(model.Post).options(sa.orm.lazyload('*')) - def create_filter_query(self, disable_eager_loads): - strategy = lazyload if disable_eager_loads else subqueryload - return db.session.query(db.Post) \ + def create_filter_query(self, disable_eager_loads: bool) -> SaQuery: + strategy = ( + sa.orm.lazyload + if disable_eager_loads + else sa.orm.subqueryload) + return db.session.query(model.Post) \ .options( - lazyload('*'), + sa.orm.lazyload('*'), # use config optimized for official client - # defer(db.Post.score), - # defer(db.Post.favorite_count), - # defer(db.Post.comment_count), - defer(db.Post.last_favorite_time), - defer(db.Post.feature_count), - defer(db.Post.last_feature_time), - defer(db.Post.last_comment_creation_time), - defer(db.Post.last_comment_edit_time), - defer(db.Post.note_count), - defer(db.Post.tag_count), - strategy(db.Post.tags).subqueryload(db.Tag.names), - strategy(db.Post.tags).defer(db.Tag.post_count), - strategy(db.Post.tags).lazyload(db.Tag.implications), - strategy(db.Post.tags).lazyload(db.Tag.suggestions)) + # sa.orm.defer(model.Post.score), + # sa.orm.defer(model.Post.favorite_count), + # sa.orm.defer(model.Post.comment_count), + sa.orm.defer(model.Post.last_favorite_time), + sa.orm.defer(model.Post.feature_count), + sa.orm.defer(model.Post.last_feature_time), + sa.orm.defer(model.Post.last_comment_creation_time), + sa.orm.defer(model.Post.last_comment_edit_time), + sa.orm.defer(model.Post.note_count), + sa.orm.defer(model.Post.tag_count), + strategy(model.Post.tags).subqueryload(model.Tag.names), + strategy(model.Post.tags).defer(model.Tag.post_count), + strategy(model.Post.tags).lazyload(model.Tag.implications), + strategy(model.Post.tags).lazyload(model.Tag.suggestions)) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.Post) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Post) - def finalize_query(self, query): - return query.order_by(db.Post.post_id.desc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Post.post_id.desc()) @property - def id_column(self): - return db.Post.post_id + def id_column(self) -> SaColumn: + return model.Post.post_id @property - def anonymous_filter(self): + def anonymous_filter(self) -> Optional[Filter]: return search_util.create_subquery_filter( - db.Post.post_id, - db.PostTag.post_id, - db.TagName.name, + model.Post.post_id, + model.PostTag.post_id, + model.TagName.name, search_util.create_str_filter, - lambda subquery: subquery.join(db.Tag).join(db.TagName)) + lambda subquery: subquery.join(model.Tag).join(model.TagName)) @property - def named_filters(self): - return util.unalias_dict({ - 'id': search_util.create_num_filter(db.Post.post_id), - 'tag': search_util.create_subquery_filter( - db.Post.post_id, - db.PostTag.post_id, - db.TagName.name, - search_util.create_str_filter, - lambda subquery: subquery.join(db.Tag).join(db.TagName)), - 'score': search_util.create_num_filter(db.Post.score), - ('uploader', 'upload', 'submit'): - _create_user_filter(), - 'comment': search_util.create_subquery_filter( - db.Post.post_id, - db.Comment.post_id, - db.User.name, - search_util.create_str_filter, - lambda subquery: subquery.join(db.User)), - 'fav': search_util.create_subquery_filter( - db.Post.post_id, - db.PostFavorite.post_id, - db.User.name, - search_util.create_str_filter, - lambda subquery: subquery.join(db.User)), - 'liked': _create_score_filter(1), - 'disliked': _create_score_filter(-1), - 'tag-count': search_util.create_num_filter(db.Post.tag_count), - 'comment-count': - search_util.create_num_filter(db.Post.comment_count), - 'fav-count': - search_util.create_num_filter(db.Post.favorite_count), - 'note-count': search_util.create_num_filter(db.Post.note_count), - 'relation-count': - search_util.create_num_filter(db.Post.relation_count), - 'feature-count': - search_util.create_num_filter(db.Post.feature_count), - 'type': + def named_filters(self) -> Dict[str, Filter]: + return util.unalias_dict([ + ( + ['id'], + search_util.create_num_filter(model.Post.post_id) + ), + + ( + ['tag'], + search_util.create_subquery_filter( + model.Post.post_id, + model.PostTag.post_id, + model.TagName.name, + search_util.create_str_filter, + lambda subquery: + subquery.join(model.Tag).join(model.TagName)) + ), + + ( + ['score'], + search_util.create_num_filter(model.Post.score) + ), + + ( + ['uploader', 'upload', 'submit'], + _create_user_filter() + ), + + ( + ['comment'], + search_util.create_subquery_filter( + model.Post.post_id, + model.Comment.post_id, + model.User.name, + search_util.create_str_filter, + lambda subquery: subquery.join(model.User)) + ), + + ( + ['fav'], + search_util.create_subquery_filter( + model.Post.post_id, + model.PostFavorite.post_id, + model.User.name, + search_util.create_str_filter, + lambda subquery: subquery.join(model.User)) + ), + + ( + ['liked'], + _create_score_filter(1) + ), + ( + ['disliked'], + _create_score_filter(-1) + ), + + ( + ['tag-count'], + search_util.create_num_filter(model.Post.tag_count) + ), + + ( + ['comment-count'], + search_util.create_num_filter(model.Post.comment_count) + ), + + ( + ['fav-count'], + search_util.create_num_filter(model.Post.favorite_count) + ), + + ( + ['note-count'], + search_util.create_num_filter(model.Post.note_count) + ), + + ( + ['relation-count'], + search_util.create_num_filter(model.Post.relation_count) + ), + + ( + ['feature-count'], + search_util.create_num_filter(model.Post.feature_count) + ), + + ( + ['type'], search_util.create_str_filter( - db.Post.type, _type_transformer), - 'content-checksum': search_util.create_str_filter( - db.Post.checksum), - 'file-size': search_util.create_num_filter(db.Post.file_size), - ('image-width', 'width'): - search_util.create_num_filter(db.Post.canvas_width), - ('image-height', 'height'): - search_util.create_num_filter(db.Post.canvas_height), - ('image-area', 'area'): - search_util.create_num_filter(db.Post.canvas_area), - ('creation-date', 'creation-time', 'date', 'time'): - search_util.create_date_filter(db.Post.creation_time), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - search_util.create_date_filter(db.Post.last_edit_time), - ('comment-date', 'comment-time'): + model.Post.type, _type_transformer) + ), + + ( + ['content-checksum'], + search_util.create_str_filter(model.Post.checksum) + ), + + ( + ['file-size'], + search_util.create_num_filter(model.Post.file_size) + ), + + ( + ['image-width', 'width'], + search_util.create_num_filter(model.Post.canvas_width) + ), + + ( + ['image-height', 'height'], + search_util.create_num_filter(model.Post.canvas_height) + ), + + ( + ['image-area', 'area'], + search_util.create_num_filter(model.Post.canvas_area) + ), + + ( + ['creation-date', 'creation-time', 'date', 'time'], + search_util.create_date_filter(model.Post.creation_time) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + search_util.create_date_filter(model.Post.last_edit_time) + ), + + ( + ['comment-date', 'comment-time'], search_util.create_date_filter( - db.Post.last_comment_creation_time), - ('fav-date', 'fav-time'): - search_util.create_date_filter(db.Post.last_favorite_time), - ('feature-date', 'feature-time'): - search_util.create_date_filter(db.Post.last_feature_time), - ('safety', 'rating'): + model.Post.last_comment_creation_time) + ), + + ( + ['fav-date', 'fav-time'], + search_util.create_date_filter(model.Post.last_favorite_time) + ), + + ( + ['feature-date', 'feature-time'], + search_util.create_date_filter(model.Post.last_feature_time) + ), + + ( + ['safety', 'rating'], search_util.create_str_filter( - db.Post.safety, _safety_transformer), - }) + model.Post.safety, _safety_transformer) + ), + ]) @property - def sort_columns(self): - return util.unalias_dict({ - 'random': (func.random(), None), - 'id': (db.Post.post_id, self.SORT_DESC), - 'score': (db.Post.score, self.SORT_DESC), - 'tag-count': (db.Post.tag_count, self.SORT_DESC), - 'comment-count': (db.Post.comment_count, self.SORT_DESC), - 'fav-count': (db.Post.favorite_count, self.SORT_DESC), - 'note-count': (db.Post.note_count, self.SORT_DESC), - 'relation-count': (db.Post.relation_count, self.SORT_DESC), - 'feature-count': (db.Post.feature_count, self.SORT_DESC), - 'file-size': (db.Post.file_size, self.SORT_DESC), - ('image-width', 'width'): - (db.Post.canvas_width, self.SORT_DESC), - ('image-height', 'height'): - (db.Post.canvas_height, self.SORT_DESC), - ('image-area', 'area'): - (db.Post.canvas_area, self.SORT_DESC), - ('creation-date', 'creation-time', 'date', 'time'): - (db.Post.creation_time, self.SORT_DESC), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - (db.Post.last_edit_time, self.SORT_DESC), - ('comment-date', 'comment-time'): - (db.Post.last_comment_creation_time, self.SORT_DESC), - ('fav-date', 'fav-time'): - (db.Post.last_favorite_time, self.SORT_DESC), - ('feature-date', 'feature-time'): - (db.Post.last_feature_time, self.SORT_DESC), - }) + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: + return util.unalias_dict([ + ( + ['random'], + (sa.sql.expression.func.random(), self.SORT_NONE) + ), + + ( + ['id'], + (model.Post.post_id, self.SORT_DESC) + ), + + ( + ['score'], + (model.Post.score, self.SORT_DESC) + ), + + ( + ['tag-count'], + (model.Post.tag_count, self.SORT_DESC) + ), + + ( + ['comment-count'], + (model.Post.comment_count, self.SORT_DESC) + ), + + ( + ['fav-count'], + (model.Post.favorite_count, self.SORT_DESC) + ), + + ( + ['note-count'], + (model.Post.note_count, self.SORT_DESC) + ), + + ( + ['relation-count'], + (model.Post.relation_count, self.SORT_DESC) + ), + + ( + ['feature-count'], + (model.Post.feature_count, self.SORT_DESC) + ), + + ( + ['file-size'], + (model.Post.file_size, self.SORT_DESC) + ), + + ( + ['image-width', 'width'], + (model.Post.canvas_width, self.SORT_DESC) + ), + + ( + ['image-height', 'height'], + (model.Post.canvas_height, self.SORT_DESC) + ), + + ( + ['image-area', 'area'], + (model.Post.canvas_area, self.SORT_DESC) + ), + + ( + ['creation-date', 'creation-time', 'date', 'time'], + (model.Post.creation_time, self.SORT_DESC) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + (model.Post.last_edit_time, self.SORT_DESC) + ), + + ( + ['comment-date', 'comment-time'], + (model.Post.last_comment_creation_time, self.SORT_DESC) + ), + + ( + ['fav-date', 'fav-time'], + (model.Post.last_favorite_time, self.SORT_DESC) + ), + + ( + ['feature-date', 'feature-time'], + (model.Post.last_feature_time, self.SORT_DESC) + ), + ]) @property - def special_filters(self): + def special_filters(self) -> Dict[str, Filter]: return { - # handled by parsed - 'fav': None, - 'liked': None, - 'disliked': None, + # handled by parser + 'fav': self.noop_filter, + 'liked': self.noop_filter, + 'disliked': self.noop_filter, 'tumbleweed': self.tumbleweed_filter, } - def tumbleweed_filter(self, query, negated): - expr = \ - (db.Post.comment_count == 0) \ - & (db.Post.favorite_count == 0) \ - & (db.Post.score == 0) + def noop_filter( + self, + query: SaQuery, + _criterion: Optional[criteria.BaseCriterion], + _negated: bool) -> SaQuery: + return query + + def tumbleweed_filter( + self, + query: SaQuery, + _criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + expr = ( + (model.Post.comment_count == 0) + & (model.Post.favorite_count == 0) + & (model.Post.score == 0)) if negated: expr = ~expr return query.filter(expr) diff --git a/server/szurubooru/search/configs/snapshot_search_config.py b/server/szurubooru/search/configs/snapshot_search_config.py index 4ea7280a..0fdb69d0 100644 --- a/server/szurubooru/search/configs/snapshot_search_config.py +++ b/server/szurubooru/search/configs/snapshot_search_config.py @@ -1,28 +1,37 @@ -from szurubooru import db +from typing import Dict +from szurubooru import db, model +from szurubooru.search.typing import SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class SnapshotSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - return db.session.query(db.Snapshot) + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Snapshot) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.Snapshot) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Snapshot) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.Snapshot.creation_time.desc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Snapshot.creation_time.desc()) @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return { - 'type': search_util.create_str_filter(db.Snapshot.resource_type), - 'id': search_util.create_str_filter(db.Snapshot.resource_name), - 'date': search_util.create_date_filter(db.Snapshot.creation_time), - 'time': search_util.create_date_filter(db.Snapshot.creation_time), - 'operation': search_util.create_str_filter(db.Snapshot.operation), - 'user': search_util.create_str_filter(db.User.name), + 'type': + search_util.create_str_filter(model.Snapshot.resource_type), + 'id': + search_util.create_str_filter(model.Snapshot.resource_name), + 'date': + search_util.create_date_filter(model.Snapshot.creation_time), + 'time': + search_util.create_date_filter(model.Snapshot.creation_time), + 'operation': + search_util.create_str_filter(model.Snapshot.operation), + 'user': + search_util.create_str_filter(model.User.name), } diff --git a/server/szurubooru/search/configs/tag_search_config.py b/server/szurubooru/search/configs/tag_search_config.py index 4595d82f..6dba5b02 100644 --- a/server/szurubooru/search/configs/tag_search_config.py +++ b/server/szurubooru/search/configs/tag_search_config.py @@ -1,79 +1,134 @@ -from sqlalchemy.orm import subqueryload, lazyload, defer -from sqlalchemy.sql.expression import func -from szurubooru import db +from typing import Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model from szurubooru.func import util +from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class TagSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - strategy = lazyload if _disable_eager_loads else subqueryload - return db.session.query(db.Tag) \ - .join(db.TagCategory) \ + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + strategy = ( + sa.orm.lazyload + if _disable_eager_loads + else sa.orm.subqueryload) + return db.session.query(model.Tag) \ + .join(model.TagCategory) \ .options( - defer(db.Tag.first_name), - defer(db.Tag.suggestion_count), - defer(db.Tag.implication_count), - defer(db.Tag.post_count), - strategy(db.Tag.names), - strategy(db.Tag.suggestions).joinedload(db.Tag.names), - strategy(db.Tag.implications).joinedload(db.Tag.names)) + sa.orm.defer(model.Tag.first_name), + sa.orm.defer(model.Tag.suggestion_count), + sa.orm.defer(model.Tag.implication_count), + sa.orm.defer(model.Tag.post_count), + strategy(model.Tag.names), + strategy(model.Tag.suggestions).joinedload(model.Tag.names), + strategy(model.Tag.implications).joinedload(model.Tag.names)) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.Tag) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Tag) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.Tag.first_name.asc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Tag.first_name.asc()) @property - def anonymous_filter(self): + def anonymous_filter(self) -> Filter: return search_util.create_subquery_filter( - db.Tag.tag_id, - db.TagName.tag_id, - db.TagName.name, + model.Tag.tag_id, + model.TagName.tag_id, + model.TagName.name, search_util.create_str_filter) @property - def named_filters(self): - return util.unalias_dict({ - 'name': search_util.create_subquery_filter( - db.Tag.tag_id, - db.TagName.tag_id, - db.TagName.name, - search_util.create_str_filter), - 'category': search_util.create_subquery_filter( - db.Tag.category_id, - db.TagCategory.tag_category_id, - db.TagCategory.name, - search_util.create_str_filter), - ('creation-date', 'creation-time'): - search_util.create_date_filter(db.Tag.creation_time), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - search_util.create_date_filter(db.Tag.last_edit_time), - ('usage-count', 'post-count', 'usages'): - search_util.create_num_filter(db.Tag.post_count), - 'suggestion-count': - search_util.create_num_filter(db.Tag.suggestion_count), - 'implication-count': - search_util.create_num_filter(db.Tag.implication_count), - }) + def named_filters(self) -> Dict[str, Filter]: + return util.unalias_dict([ + ( + ['name'], + search_util.create_subquery_filter( + model.Tag.tag_id, + model.TagName.tag_id, + model.TagName.name, + search_util.create_str_filter) + ), + + ( + ['category'], + search_util.create_subquery_filter( + model.Tag.category_id, + model.TagCategory.tag_category_id, + model.TagCategory.name, + search_util.create_str_filter) + ), + + ( + ['creation-date', 'creation-time'], + search_util.create_date_filter(model.Tag.creation_time) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + search_util.create_date_filter(model.Tag.last_edit_time) + ), + + ( + ['usage-count', 'post-count', 'usages'], + search_util.create_num_filter(model.Tag.post_count) + ), + + ( + ['suggestion-count'], + search_util.create_num_filter(model.Tag.suggestion_count) + ), + + ( + ['implication-count'], + search_util.create_num_filter(model.Tag.implication_count) + ), + ]) @property - def sort_columns(self): - return util.unalias_dict({ - 'random': (func.random(), None), - 'name': (db.Tag.first_name, self.SORT_ASC), - 'category': (db.TagCategory.name, self.SORT_ASC), - ('creation-date', 'creation-time'): - (db.Tag.creation_time, self.SORT_DESC), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - (db.Tag.last_edit_time, self.SORT_DESC), - ('usage-count', 'post-count', 'usages'): - (db.Tag.post_count, self.SORT_DESC), - 'suggestion-count': (db.Tag.suggestion_count, self.SORT_DESC), - 'implication-count': (db.Tag.implication_count, self.SORT_DESC), - }) + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: + return util.unalias_dict([ + ( + ['random'], + (sa.sql.expression.func.random(), self.SORT_NONE) + ), + + ( + ['name'], + (model.Tag.first_name, self.SORT_ASC) + ), + + ( + ['category'], + (model.TagCategory.name, self.SORT_ASC) + ), + + ( + ['creation-date', 'creation-time'], + (model.Tag.creation_time, self.SORT_DESC) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + (model.Tag.last_edit_time, self.SORT_DESC) + ), + + ( + ['usage-count', 'post-count', 'usages'], + (model.Tag.post_count, self.SORT_DESC) + ), + + ( + ['suggestion-count'], + (model.Tag.suggestion_count, self.SORT_DESC) + ), + + ( + ['implication-count'], + (model.Tag.implication_count, self.SORT_DESC) + ), + ]) diff --git a/server/szurubooru/search/configs/user_search_config.py b/server/szurubooru/search/configs/user_search_config.py index c7e727e6..64534009 100644 --- a/server/szurubooru/search/configs/user_search_config.py +++ b/server/szurubooru/search/configs/user_search_config.py @@ -1,53 +1,57 @@ -from sqlalchemy.sql.expression import func -from szurubooru import db +from typing import Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model +from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class UserSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - return db.session.query(db.User) + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.User) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.User) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.User) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.User.name.asc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.User.name.asc()) @property - def anonymous_filter(self): - return search_util.create_str_filter(db.User.name) + def anonymous_filter(self) -> Filter: + return search_util.create_str_filter(model.User.name) @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return { - 'name': search_util.create_str_filter(db.User.name), + 'name': + search_util.create_str_filter(model.User.name), 'creation-date': - search_util.create_date_filter(db.User.creation_time), + search_util.create_date_filter(model.User.creation_time), 'creation-time': - search_util.create_date_filter(db.User.creation_time), + search_util.create_date_filter(model.User.creation_time), 'last-login-date': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), 'last-login-time': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), 'login-date': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), 'login-time': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), } @property - def sort_columns(self): + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: return { - 'random': (func.random(), None), - 'name': (db.User.name, self.SORT_ASC), - 'creation-date': (db.User.creation_time, self.SORT_DESC), - 'creation-time': (db.User.creation_time, self.SORT_DESC), - 'last-login-date': (db.User.last_login_time, self.SORT_DESC), - 'last-login-time': (db.User.last_login_time, self.SORT_DESC), - 'login-date': (db.User.last_login_time, self.SORT_DESC), - 'login-time': (db.User.last_login_time, self.SORT_DESC), + 'random': (sa.sql.expression.func.random(), self.SORT_NONE), + 'name': (model.User.name, self.SORT_ASC), + 'creation-date': (model.User.creation_time, self.SORT_DESC), + 'creation-time': (model.User.creation_time, self.SORT_DESC), + 'last-login-date': (model.User.last_login_time, self.SORT_DESC), + 'last-login-time': (model.User.last_login_time, self.SORT_DESC), + 'login-date': (model.User.last_login_time, self.SORT_DESC), + 'login-time': (model.User.last_login_time, self.SORT_DESC), } diff --git a/server/szurubooru/search/configs/util.py b/server/szurubooru/search/configs/util.py index 2eaaf8d7..086f3921 100644 --- a/server/szurubooru/search/configs/util.py +++ b/server/szurubooru/search/configs/util.py @@ -1,10 +1,13 @@ -import sqlalchemy +from typing import Any, Optional, Callable +import sqlalchemy as sa from szurubooru import db, errors from szurubooru.func import util from szurubooru.search import criteria +from szurubooru.search.typing import SaColumn, SaQuery +from szurubooru.search.configs.base_search_config import Filter -def wildcard_transformer(value): +def wildcard_transformer(value: str) -> str: return ( value .replace('\\', '\\\\') @@ -13,24 +16,21 @@ def wildcard_transformer(value): .replace('*', '%')) -def apply_num_criterion_to_column(column, criterion): - ''' - Decorate SQLAlchemy filter on given column using supplied criterion. - ''' +def apply_num_criterion_to_column( + column: Any, criterion: criteria.BaseCriterion) -> Any: try: if isinstance(criterion, criteria.PlainCriterion): expr = column == int(criterion.value) elif isinstance(criterion, criteria.ArrayCriterion): expr = column.in_(int(value) for value in criterion.values) elif isinstance(criterion, criteria.RangedCriterion): - assert criterion.min_value != '' \ - or criterion.max_value != '' - if criterion.min_value != '' and criterion.max_value != '': + assert criterion.min_value or criterion.max_value + if criterion.min_value and criterion.max_value: expr = column.between( int(criterion.min_value), int(criterion.max_value)) - elif criterion.min_value != '': + elif criterion.min_value: expr = column >= int(criterion.min_value) - elif criterion.max_value != '': + elif criterion.max_value: expr = column <= int(criterion.max_value) else: assert False @@ -40,10 +40,13 @@ def apply_num_criterion_to_column(column, criterion): return expr -def create_num_filter(column): - def wrapper(query, criterion, negated): - expr = apply_num_criterion_to_column( - column, criterion) +def create_num_filter(column: Any) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion + expr = apply_num_criterion_to_column(column, criterion) if negated: expr = ~expr return query.filter(expr) @@ -51,14 +54,13 @@ def create_num_filter(column): def apply_str_criterion_to_column( - column, criterion, transformer=wildcard_transformer): - ''' - Decorate SQLAlchemy filter on given column using supplied criterion. - ''' + column: SaColumn, + criterion: criteria.BaseCriterion, + transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery: if isinstance(criterion, criteria.PlainCriterion): expr = column.ilike(transformer(criterion.value)) elif isinstance(criterion, criteria.ArrayCriterion): - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for value in criterion.values: expr = expr | column.ilike(transformer(value)) elif isinstance(criterion, criteria.RangedCriterion): @@ -68,8 +70,15 @@ def apply_str_criterion_to_column( return expr -def create_str_filter(column, transformer=wildcard_transformer): - def wrapper(query, criterion, negated): +def create_str_filter( + column: SaColumn, + transformer: Callable[[str], str]=wildcard_transformer +) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion expr = apply_str_criterion_to_column( column, criterion, transformer) if negated: @@ -78,16 +87,13 @@ def create_str_filter(column, transformer=wildcard_transformer): return wrapper -def apply_date_criterion_to_column(column, criterion): - ''' - Decorate SQLAlchemy filter on given column using supplied criterion. - Parse the datetime inside the criterion. - ''' +def apply_date_criterion_to_column( + column: SaQuery, criterion: criteria.BaseCriterion) -> SaQuery: if isinstance(criterion, criteria.PlainCriterion): min_date, max_date = util.parse_time_range(criterion.value) expr = column.between(min_date, max_date) elif isinstance(criterion, criteria.ArrayCriterion): - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for value in criterion.values: min_date, max_date = util.parse_time_range(value) expr = expr | column.between(min_date, max_date) @@ -108,10 +114,13 @@ def apply_date_criterion_to_column(column, criterion): return expr -def create_date_filter(column): - def wrapper(query, criterion, negated): - expr = apply_date_criterion_to_column( - column, criterion) +def create_date_filter(column: SaColumn) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion + expr = apply_date_criterion_to_column(column, criterion) if negated: expr = ~expr return query.filter(expr) @@ -119,18 +128,22 @@ def create_date_filter(column): def create_subquery_filter( - left_id_column, - right_id_column, - filter_column, - filter_factory, - subquery_decorator=None): + left_id_column: SaColumn, + right_id_column: SaColumn, + filter_column: SaColumn, + filter_factory: SaColumn, + subquery_decorator: Callable[[SaQuery], None]=None) -> Filter: filter_func = filter_factory(filter_column) - def wrapper(query, criterion, negated): + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion subquery = db.session.query(right_id_column.label('foreign_id')) if subquery_decorator: subquery = subquery_decorator(subquery) - subquery = subquery.options(sqlalchemy.orm.lazyload('*')) + subquery = subquery.options(sa.orm.lazyload('*')) subquery = filter_func(subquery, criterion, False) subquery = subquery.subquery('t') expression = left_id_column.in_(subquery) diff --git a/server/szurubooru/search/criteria.py b/server/szurubooru/search/criteria.py index 9d4dc664..7b1fee31 100644 --- a/server/szurubooru/search/criteria.py +++ b/server/szurubooru/search/criteria.py @@ -1,34 +1,42 @@ -class _BaseCriterion: - def __init__(self, original_text): +from typing import Optional, List, Callable +from szurubooru.search.typing import SaQuery + + +class BaseCriterion: + def __init__(self, original_text: str) -> None: self.original_text = original_text - def __repr__(self): + def __repr__(self) -> str: return self.original_text -class RangedCriterion(_BaseCriterion): - def __init__(self, original_text, min_value, max_value): +class RangedCriterion(BaseCriterion): + def __init__( + self, + original_text: str, + min_value: Optional[str], + max_value: Optional[str]) -> None: super().__init__(original_text) self.min_value = min_value self.max_value = max_value - def __hash__(self): + def __hash__(self) -> int: return hash(('range', self.min_value, self.max_value)) -class PlainCriterion(_BaseCriterion): - def __init__(self, original_text, value): +class PlainCriterion(BaseCriterion): + def __init__(self, original_text: str, value: str) -> None: super().__init__(original_text) self.value = value - def __hash__(self): + def __hash__(self) -> int: return hash(self.value) -class ArrayCriterion(_BaseCriterion): - def __init__(self, original_text, values): +class ArrayCriterion(BaseCriterion): + def __init__(self, original_text: str, values: List[str]) -> None: super().__init__(original_text) self.values = values - def __hash__(self): + def __hash__(self) -> int: return hash(tuple(['array'] + self.values)) diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index d9adc940..2819593e 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -1,14 +1,18 @@ -import sqlalchemy -from szurubooru import db, errors +from typing import Union, Tuple, List, Dict, Callable +import sqlalchemy as sa +from szurubooru import db, model, errors, rest from szurubooru.func import cache from szurubooru.search import tokens, parser +from szurubooru.search.typing import SaQuery +from szurubooru.search.query import SearchQuery +from szurubooru.search.configs.base_search_config import BaseSearchConfig -def _format_dict_keys(source): +def _format_dict_keys(source: Dict) -> List[str]: return list(sorted(source.keys())) -def _get_order(order, default_order): +def _get_order(order: str, default_order: str) -> Union[bool, str]: if order == tokens.SortToken.SORT_DEFAULT: return default_order or tokens.SortToken.SORT_ASC if order == tokens.SortToken.SORT_NEGATED_DEFAULT: @@ -26,50 +30,57 @@ class Executor: delegates sqlalchemy filter decoration to SearchConfig instances. ''' - def __init__(self, search_config): + def __init__(self, search_config: BaseSearchConfig) -> None: self.config = search_config self.parser = parser.Parser() - def get_around(self, query_text, entity_id): + def get_around( + self, + query_text: str, + entity_id: int) -> Tuple[model.Base, model.Base]: search_query = self.parser.parse(query_text) self.config.on_search_query_parsed(search_query) filter_query = ( self.config .create_around_query() - .options(sqlalchemy.orm.lazyload('*'))) + .options(sa.orm.lazyload('*'))) filter_query = self._prepare_db_query( filter_query, search_query, False) prev_filter_query = ( filter_query .filter(self.config.id_column > entity_id) .order_by(None) - .order_by(sqlalchemy.func.abs( - self.config.id_column - entity_id).asc()) + .order_by(sa.func.abs(self.config.id_column - entity_id).asc()) .limit(1)) next_filter_query = ( filter_query .filter(self.config.id_column < entity_id) .order_by(None) - .order_by(sqlalchemy.func.abs( - self.config.id_column - entity_id).asc()) + .order_by(sa.func.abs(self.config.id_column - entity_id).asc()) .limit(1)) - return [ + return ( prev_filter_query.one_or_none(), - next_filter_query.one_or_none()] + next_filter_query.one_or_none()) - def get_around_and_serialize(self, ctx, entity_id, serializer): - entities = self.get_around(ctx.get_param_as_string('query'), entity_id) + def get_around_and_serialize( + self, + ctx: rest.Context, + entity_id: int, + serializer: Callable[[model.Base], rest.Response] + ) -> rest.Response: + entities = self.get_around( + ctx.get_param_as_string('query', default=''), entity_id) return { 'prev': serializer(entities[0]), 'next': serializer(entities[1]), } - def execute(self, query_text, page, page_size): - ''' - Parse input and return tuple containing total record count and filtered - entities. - ''' - + def execute( + self, + query_text: str, + page: int, + page_size: int + ) -> Tuple[int, List[model.Base]]: search_query = self.parser.parse(query_text) self.config.on_search_query_parsed(search_query) @@ -83,7 +94,7 @@ class Executor: return cache.get(key) filter_query = self.config.create_filter_query(disable_eager_loads) - filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) + filter_query = filter_query.options(sa.orm.lazyload('*')) filter_query = self._prepare_db_query(filter_query, search_query, True) entities = filter_query \ .offset(max(page - 1, 0) * page_size) \ @@ -91,11 +102,11 @@ class Executor: .all() count_query = self.config.create_count_query(disable_eager_loads) - count_query = count_query.options(sqlalchemy.orm.lazyload('*')) + count_query = count_query.options(sa.orm.lazyload('*')) count_query = self._prepare_db_query(count_query, search_query, False) count_statement = count_query \ .statement \ - .with_only_columns([sqlalchemy.func.count()]) \ + .with_only_columns([sa.func.count()]) \ .order_by(None) count = db.session.execute(count_statement).scalar() @@ -103,8 +114,12 @@ class Executor: cache.put(key, ret) return ret - def execute_and_serialize(self, ctx, serializer): - query = ctx.get_param_as_string('query') + def execute_and_serialize( + self, + ctx: rest.Context, + serializer: Callable[[model.Base], rest.Response] + ) -> rest.Response: + query = ctx.get_param_as_string('query', default='') page = ctx.get_param_as_int('page', default=1, min=1) page_size = ctx.get_param_as_int( 'pageSize', default=100, min=1, max=100) @@ -117,48 +132,51 @@ class Executor: 'results': [serializer(entity) for entity in entities], } - def _prepare_db_query(self, db_query, search_query, use_sort): - ''' Parse input and return SQLAlchemy query. ''' - - for token in search_query.anonymous_tokens: + def _prepare_db_query( + self, + db_query: SaQuery, + search_query: SearchQuery, + use_sort: bool) -> SaQuery: + for anon_token in search_query.anonymous_tokens: if not self.config.anonymous_filter: raise errors.SearchError( 'Anonymous tokens are not valid in this context.') db_query = self.config.anonymous_filter( - db_query, token.criterion, token.negated) + db_query, anon_token.criterion, anon_token.negated) - for token in search_query.named_tokens: - if token.name not in self.config.named_filters: + for named_token in search_query.named_tokens: + if named_token.name not in self.config.named_filters: raise errors.SearchError( 'Unknown named token: %r. Available named tokens: %r.' % ( - token.name, + named_token.name, _format_dict_keys(self.config.named_filters))) - db_query = self.config.named_filters[token.name]( - db_query, token.criterion, token.negated) + db_query = self.config.named_filters[named_token.name]( + db_query, named_token.criterion, named_token.negated) - for token in search_query.special_tokens: - if token.value not in self.config.special_filters: + for sp_token in search_query.special_tokens: + if sp_token.value not in self.config.special_filters: raise errors.SearchError( 'Unknown special token: %r. ' 'Available special tokens: %r.' % ( - token.value, + sp_token.value, _format_dict_keys(self.config.special_filters))) - db_query = self.config.special_filters[token.value]( - db_query, token.negated) + db_query = self.config.special_filters[sp_token.value]( + db_query, None, sp_token.negated) if use_sort: - for token in search_query.sort_tokens: - if token.name not in self.config.sort_columns: + for sort_token in search_query.sort_tokens: + if sort_token.name not in self.config.sort_columns: raise errors.SearchError( 'Unknown sort token: %r. ' 'Available sort tokens: %r.' % ( - token.name, + sort_token.name, _format_dict_keys(self.config.sort_columns))) - column, default_order = self.config.sort_columns[token.name] - order = _get_order(token.order, default_order) - if order == token.SORT_ASC: + column, default_order = ( + self.config.sort_columns[sort_token.name]) + order = _get_order(sort_token.order, default_order) + if order == sort_token.SORT_ASC: db_query = db_query.order_by(column.asc()) - elif order == token.SORT_DESC: + elif order == sort_token.SORT_DESC: db_query = db_query.order_by(column.desc()) db_query = self.config.finalize_query(db_query) diff --git a/server/szurubooru/search/parser.py b/server/szurubooru/search/parser.py index 33b41173..93affe26 100644 --- a/server/szurubooru/search/parser.py +++ b/server/szurubooru/search/parser.py @@ -1,9 +1,12 @@ import re +from typing import List from szurubooru import errors from szurubooru.search import criteria, tokens +from szurubooru.search.query import SearchQuery -def _create_criterion(original_value, value): +def _create_criterion( + original_value: str, value: str) -> criteria.BaseCriterion: if ',' in value: return criteria.ArrayCriterion( original_value, value.split(',')) @@ -15,12 +18,12 @@ def _create_criterion(original_value, value): return criteria.PlainCriterion(original_value, value) -def _parse_anonymous(value, negated): +def _parse_anonymous(value: str, negated: bool) -> tokens.AnonymousToken: criterion = _create_criterion(value, value) return tokens.AnonymousToken(criterion, negated) -def _parse_named(key, value, negated): +def _parse_named(key: str, value: str, negated: bool) -> tokens.NamedToken: original_value = value if key.endswith('-min'): key = key[:-4] @@ -32,11 +35,11 @@ def _parse_named(key, value, negated): return tokens.NamedToken(key, criterion, negated) -def _parse_special(value, negated): +def _parse_special(value: str, negated: bool) -> tokens.SpecialToken: return tokens.SpecialToken(value, negated) -def _parse_sort(value, negated): +def _parse_sort(value: str, negated: bool) -> tokens.SortToken: if value.count(',') == 0: order_str = None elif value.count(',') == 1: @@ -67,23 +70,8 @@ def _parse_sort(value, negated): return tokens.SortToken(value, order) -class SearchQuery: - def __init__(self): - self.anonymous_tokens = [] - self.named_tokens = [] - self.special_tokens = [] - self.sort_tokens = [] - - def __hash__(self): - return hash(( - tuple(self.anonymous_tokens), - tuple(self.named_tokens), - tuple(self.special_tokens), - tuple(self.sort_tokens))) - - class Parser: - def parse(self, query_text): + def parse(self, query_text: str) -> SearchQuery: query = SearchQuery() for chunk in re.split(r'\s+', (query_text or '').lower()): if not chunk: diff --git a/server/szurubooru/search/query.py b/server/szurubooru/search/query.py new file mode 100644 index 00000000..7d29dbd3 --- /dev/null +++ b/server/szurubooru/search/query.py @@ -0,0 +1,16 @@ +from szurubooru.search import tokens + + +class SearchQuery: + def __init__(self) -> None: + self.anonymous_tokens = [] # type: List[tokens.AnonymousToken] + self.named_tokens = [] # type: List[tokens.NamedToken] + self.special_tokens = [] # type: List[tokens.SpecialToken] + self.sort_tokens = [] # type: List[tokens.SortToken] + + def __hash__(self) -> int: + return hash(( + tuple(self.anonymous_tokens), + tuple(self.named_tokens), + tuple(self.special_tokens), + tuple(self.sort_tokens))) diff --git a/server/szurubooru/search/tokens.py b/server/szurubooru/search/tokens.py index cff7dc5f..0cd7fd7d 100644 --- a/server/szurubooru/search/tokens.py +++ b/server/szurubooru/search/tokens.py @@ -1,39 +1,44 @@ +from szurubooru.search.criteria import BaseCriterion + + class AnonymousToken: - def __init__(self, criterion, negated): + def __init__(self, criterion: BaseCriterion, negated: bool) -> None: self.criterion = criterion self.negated = negated - def __hash__(self): + def __hash__(self) -> int: return hash((self.criterion, self.negated)) class NamedToken(AnonymousToken): - def __init__(self, name, criterion, negated): + def __init__( + self, name: str, criterion: BaseCriterion, negated: bool) -> None: super().__init__(criterion, negated) self.name = name - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.criterion, self.negated)) class SortToken: SORT_DESC = 'desc' SORT_ASC = 'asc' + SORT_NONE = '' SORT_DEFAULT = 'default' SORT_NEGATED_DEFAULT = 'negated default' - def __init__(self, name, order): + def __init__(self, name: str, order: str) -> None: self.name = name self.order = order - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.order)) class SpecialToken: - def __init__(self, value, negated): + def __init__(self, value: str, negated: bool) -> None: self.value = value self.negated = negated - def __hash__(self): + def __hash__(self) -> int: return hash((self.value, self.negated)) diff --git a/server/szurubooru/search/typing.py b/server/szurubooru/search/typing.py new file mode 100644 index 00000000..ebb1b30d --- /dev/null +++ b/server/szurubooru/search/typing.py @@ -0,0 +1,6 @@ +from typing import Any, Callable + + +SaColumn = Any +SaQuery = Any +SaQueryFactory = Callable[[], SaQuery] diff --git a/server/szurubooru/tests/api/test_comment_creating.py b/server/szurubooru/tests/api/test_comment_creating.py index c7d0b0f6..ad243661 100644 --- a/server/szurubooru/tests/api/test_comment_creating.py +++ b/server/szurubooru/tests/api/test_comment_creating.py @@ -1,19 +1,20 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments, posts @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) + config_injector( + {'privileges': {'comments:create': model.User.RANK_REGULAR}}) def test_creating_comment( user_factory, post_factory, context_factory, fake_datetime): post = post_factory() - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() with patch('szurubooru.func.comments.serialize_comment'), \ @@ -24,7 +25,7 @@ def test_creating_comment( params={'text': 'input', 'postId': post.post_id}, user=user)) assert result == 'serialized comment' - comment = db.session.query(db.Comment).one() + comment = db.session.query(model.Comment).one() assert comment.text == 'input' assert comment.creation_time == datetime(1997, 1, 1) assert comment.last_edit_time is None @@ -41,7 +42,7 @@ def test_creating_comment( def test_trying_to_pass_invalid_params( user_factory, post_factory, context_factory, params): post = post_factory() - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() real_params = {'text': 'input', 'postId': post.post_id} @@ -63,11 +64,11 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): api.comment_api.create_comment( context_factory( params={}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_comment_non_existing(user_factory, context_factory): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) db.session.add_all([user]) db.session.flush() with pytest.raises(posts.PostNotFoundError): @@ -81,4 +82,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory): api.comment_api.create_comment( context_factory( params={}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_comment_deleting.py b/server/szurubooru/tests/api/test_comment_deleting.py index efb432a6..e1d1baa0 100644 --- a/server/szurubooru/tests/api/test_comment_deleting.py +++ b/server/szurubooru/tests/api/test_comment_deleting.py @@ -1,5 +1,5 @@ import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @@ -7,8 +7,8 @@ from szurubooru.func import comments def inject_config(config_injector): config_injector({ 'privileges': { - 'comments:delete:own': db.User.RANK_REGULAR, - 'comments:delete:any': db.User.RANK_MODERATOR, + 'comments:delete:own': model.User.RANK_REGULAR, + 'comments:delete:any': model.User.RANK_MODERATOR, }, }) @@ -22,26 +22,26 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory): context_factory(params={'version': 1}, user=user), {'comment_id': comment.comment_id}) assert result == {} - assert db.session.query(db.Comment).count() == 0 + assert db.session.query(model.Comment).count() == 0 def test_deleting_someones_else_comment( user_factory, comment_factory, context_factory): - user1 = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_MODERATOR) + user1 = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_MODERATOR) comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() api.comment_api.delete_comment( context_factory(params={'version': 1}, user=user2), {'comment_id': comment.comment_id}) - assert db.session.query(db.Comment).count() == 0 + assert db.session.query(model.Comment).count() == 0 def test_trying_to_delete_someones_else_comment_without_privileges( user_factory, comment_factory, context_factory): - user1 = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_REGULAR) + user1 = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() @@ -49,7 +49,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges( api.comment_api.delete_comment( context_factory(params={'version': 1}, user=user2), {'comment_id': comment.comment_id}) - assert db.session.query(db.Comment).count() == 1 + assert db.session.query(model.Comment).count() == 1 def test_trying_to_delete_non_existing(user_factory, context_factory): @@ -57,5 +57,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory): api.comment_api.delete_comment( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 1}) diff --git a/server/szurubooru/tests/api/test_comment_rating.py b/server/szurubooru/tests/api/test_comment_rating.py index 981e0dd8..aae5e241 100644 --- a/server/szurubooru/tests/api/test_comment_rating.py +++ b/server/szurubooru/tests/api/test_comment_rating.py @@ -1,17 +1,18 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) + config_injector( + {'privileges': {'comments:score': model.User.RANK_REGULAR}}) def test_simple_rating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -22,14 +23,14 @@ def test_simple_rating( context_factory(params={'score': 1}, user=user), {'comment_id': comment.comment_id}) assert result == 'serialized comment' - assert db.session.query(db.CommentScore).count() == 1 + assert db.session.query(model.CommentScore).count() == 1 assert comment is not None assert comment.score == 1 def test_updating_rating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -42,14 +43,14 @@ def test_updating_rating( api.comment_api.set_comment_score( context_factory(params={'score': -1}, user=user), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 1 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 1 assert comment.score == -1 def test_updating_rating_to_zero( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -62,14 +63,14 @@ def test_updating_rating_to_zero( api.comment_api.set_comment_score( context_factory(params={'score': 0}, user=user), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 0 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 0 assert comment.score == 0 def test_deleting_rating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -82,15 +83,15 @@ def test_deleting_rating( api.comment_api.delete_comment_score( context_factory(user=user), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 0 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 0 assert comment.score == 0 def test_ratings_from_multiple_users( user_factory, comment_factory, context_factory, fake_datetime): - user1 = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_REGULAR) + user1 = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory() db.session.add_all([user1, user2, comment]) db.session.commit() @@ -103,8 +104,8 @@ def test_ratings_from_multiple_users( api.comment_api.set_comment_score( context_factory(params={'score': -1}, user=user2), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 2 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 2 assert comment.score == 0 @@ -125,7 +126,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.comment_api.set_comment_score( context_factory( params={'score': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 5}) @@ -138,5 +139,5 @@ def test_trying_to_rate_without_privileges( api.comment_api.set_comment_score( context_factory( params={'score': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'comment_id': comment.comment_id}) diff --git a/server/szurubooru/tests/api/test_comment_retrieving.py b/server/szurubooru/tests/api/test_comment_retrieving.py index 908e9eb8..e0378fa2 100644 --- a/server/szurubooru/tests/api/test_comment_retrieving.py +++ b/server/szurubooru/tests/api/test_comment_retrieving.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @@ -8,8 +8,8 @@ from szurubooru.func import comments def inject_config(config_injector): config_injector({ 'privileges': { - 'comments:list': db.User.RANK_REGULAR, - 'comments:view': db.User.RANK_REGULAR, + 'comments:list': model.User.RANK_REGULAR, + 'comments:view': model.User.RANK_REGULAR, }, }) @@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory): result = api.comment_api.get_comments( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges( api.comment_api.get_comments( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, comment_factory, context_factory): @@ -51,7 +51,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory): comments.serialize_comment.return_value = 'serialized comment' result = api.comment_api.get_comment( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': comment.comment_id}) assert result == 'serialized comment' @@ -60,7 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): api.comment_api.get_comment( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 5}) @@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): api.comment_api.get_comment( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'comment_id': 5}) diff --git a/server/szurubooru/tests/api/test_comment_updating.py b/server/szurubooru/tests/api/test_comment_updating.py index 5f3d12b0..761b1ce0 100644 --- a/server/szurubooru/tests/api/test_comment_updating.py +++ b/server/szurubooru/tests/api/test_comment_updating.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @@ -9,15 +9,15 @@ from szurubooru.func import comments def inject_config(config_injector): config_injector({ 'privileges': { - 'comments:edit:own': db.User.RANK_REGULAR, - 'comments:edit:any': db.User.RANK_MODERATOR, + 'comments:edit:own': model.User.RANK_REGULAR, + 'comments:edit:any': model.User.RANK_MODERATOR, }, }) def test_simple_updating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -73,14 +73,14 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.comment_api.update_comment( context_factory( params={'text': 'new text'}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 5}) def test_trying_to_update_someones_comment_without_privileges( user_factory, comment_factory, context_factory): - user = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -93,8 +93,8 @@ def test_trying_to_update_someones_comment_without_privileges( def test_updating_someones_comment_with_privileges( user_factory, comment_factory, context_factory): - user = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_MODERATOR) + user = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_MODERATOR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() diff --git a/server/szurubooru/tests/api/test_password_reset.py b/server/szurubooru/tests/api/test_password_reset.py index 52b568da..e46dbbec 100644 --- a/server/szurubooru/tests/api/test_password_reset.py +++ b/server/szurubooru/tests/api/test_password_reset.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import auth, mailer @@ -15,7 +15,7 @@ def inject_config(config_injector): def test_reset_sending_email(context_factory, user_factory): db.session.add(user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) db.session.flush() for initiating_user in ['u1', 'user@example.com']: with patch('szurubooru.func.mailer.send_mail'): @@ -39,7 +39,7 @@ def test_trying_to_reset_non_existing(context_factory): def test_trying_to_reset_without_email(context_factory, user_factory): db.session.add( - user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) + user_factory(name='u1', rank=model.User.RANK_REGULAR, email=None)) db.session.flush() with pytest.raises(errors.ValidationError): api.password_reset_api.start_password_reset( @@ -48,7 +48,7 @@ def test_trying_to_reset_without_email(context_factory, user_factory): def test_confirming_with_good_token(context_factory, user_factory): user = user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com') old_hash = user.password_hash db.session.add(user) db.session.flush() @@ -68,7 +68,7 @@ def test_trying_to_confirm_non_existing(context_factory): def test_trying_to_confirm_without_token(context_factory, user_factory): db.session.add(user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) db.session.flush() with pytest.raises(errors.ValidationError): api.password_reset_api.finish_password_reset( @@ -77,7 +77,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory): def test_trying_to_confirm_with_bad_token(context_factory, user_factory): db.session.add(user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) db.session.flush() with pytest.raises(errors.ValidationError): api.password_reset_api.finish_password_reset( diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 9737a73b..a653b3bf 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, tags, snapshots, net @@ -8,16 +8,16 @@ from szurubooru.func import posts, tags, snapshots, net def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:create:anonymous': db.User.RANK_REGULAR, - 'posts:create:identified': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_REGULAR, + 'posts:create:anonymous': model.User.RANK_REGULAR, + 'posts:create:identified': model.User.RANK_REGULAR, + 'tags:create': model.User.RANK_REGULAR, }, }) def test_creating_minimal_posts( context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -53,20 +53,20 @@ def test_creating_minimal_posts( posts.update_post_thumbnail.assert_called_once_with( post, 'post-thumbnail') posts.update_post_safety.assert_called_once_with(post, 'safe') - posts.update_post_source.assert_called_once_with(post, None) + posts.update_post_source.assert_called_once_with(post, '') posts.update_post_relations.assert_called_once_with(post, []) posts.update_post_notes.assert_called_once_with(post, []) posts.update_post_flags.assert_called_once_with(post, []) posts.update_post_thumbnail.assert_called_once_with( post, 'post-thumbnail') posts.serialize_post.assert_called_once_with( - post, auth_user, options=None) + post, auth_user, options=[]) snapshots.create.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() def test_creating_full_posts(context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -109,14 +109,14 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): posts.update_post_flags.assert_called_once_with( post, ['flag1', 'flag2']) posts.serialize_post.assert_called_once_with( - post, auth_user, options=None) + post, auth_user, options=[]) snapshots.create.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() def test_anonymous_uploads( config_injector, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -126,7 +126,7 @@ def test_anonymous_uploads( patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ - 'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, + 'privileges': {'posts:create:anonymous': model.User.RANK_REGULAR}, }) posts.create_post.return_value = [post, []] api.post_api.create_post( @@ -146,7 +146,7 @@ def test_anonymous_uploads( def test_creating_from_url_saves_source( config_injector, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -157,7 +157,7 @@ def test_creating_from_url_saves_source( patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ - 'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, + 'privileges': {'posts:create:identified': model.User.RANK_REGULAR}, }) net.download.return_value = b'content' posts.create_post.return_value = [post, []] @@ -177,7 +177,7 @@ def test_creating_from_url_saves_source( def test_creating_from_url_with_source_specified( config_injector, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -188,7 +188,7 @@ def test_creating_from_url_with_source_specified( patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ - 'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, + 'privileges': {'posts:create:identified': model.User.RANK_REGULAR}, }) net.download.return_value = b'content' posts.create_post.return_value = [post, []] @@ -218,14 +218,14 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): context_factory( params=params, files={'content': '...'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) @pytest.mark.parametrize( 'field', ['tags', 'relations', 'source', 'notes', 'flags']) def test_omitting_optional_field( field, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -268,10 +268,10 @@ def test_errors_not_spending_ids( 'post_height': 300, }, 'privileges': { - 'posts:create:identified': db.User.RANK_REGULAR, + 'posts:create:identified': model.User.RANK_REGULAR, }, }) - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) # successful request with patch('szurubooru.func.posts.serialize_post'), \ @@ -316,7 +316,7 @@ def test_trying_to_omit_content(context_factory, user_factory): 'safety': 'safe', 'tags': ['tag1', 'tag2'], }, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_create_post_without_privileges( @@ -324,16 +324,16 @@ def test_trying_to_create_post_without_privileges( with pytest.raises(errors.AuthError): api.post_api.create_post(context_factory( params='whatever', - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_trying_to_create_tags_without_privileges( config_injector, context_factory, user_factory): config_injector({ 'privileges': { - 'posts:create:anonymous': db.User.RANK_REGULAR, - 'posts:create:identified': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_ADMINISTRATOR, + 'posts:create:anonymous': model.User.RANK_REGULAR, + 'posts:create:identified': model.User.RANK_REGULAR, + 'tags:create': model.User.RANK_ADMINISTRATOR, }, }) with pytest.raises(errors.AuthError), \ @@ -349,4 +349,4 @@ def test_trying_to_create_tags_without_privileges( files={ 'content': posts.EMPTY_PIXEL, }, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) diff --git a/server/szurubooru/tests/api/test_post_deleting.py b/server/szurubooru/tests/api/test_post_deleting.py index c4187ed4..643b952c 100644 --- a/server/szurubooru/tests/api/test_post_deleting.py +++ b/server/szurubooru/tests/api/test_post_deleting.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}}) def test_deleting(user_factory, post_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory(id=1) db.session.add(post) db.session.flush() @@ -20,7 +20,7 @@ def test_deleting(user_factory, post_factory, context_factory): context_factory(params={'version': 1}, user=auth_user), {'post_id': 1}) assert result == {} - assert db.session.query(db.Post).count() == 0 + assert db.session.query(model.Post).count() == 0 snapshots.delete.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() @@ -28,7 +28,7 @@ def test_deleting(user_factory, post_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.delete_post( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 999}) @@ -38,6 +38,6 @@ def test_trying_to_delete_without_privileges( db.session.commit() with pytest.raises(errors.AuthError): api.post_api.delete_post( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': 1}) - assert db.session.query(db.Post).count() == 1 + assert db.session.query(model.Post).count() == 1 diff --git a/server/szurubooru/tests/api/test_post_favoriting.py b/server/szurubooru/tests/api/test_post_favoriting.py index d78d199e..ce91a028 100644 --- a/server/szurubooru/tests/api/test_post_favoriting.py +++ b/server/szurubooru/tests/api/test_post_favoriting.py @@ -1,13 +1,14 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) + config_injector( + {'privileges': {'posts:favorite': model.User.RANK_REGULAR}}) def test_adding_to_favorites( @@ -23,8 +24,8 @@ def test_adding_to_favorites( context_factory(user=user_factory()), {'post_id': post.post_id}) assert result == 'serialized post' - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 1 assert post is not None assert post.favorite_count == 1 assert post.score == 1 @@ -47,9 +48,9 @@ def test_removing_from_favorites( api.post_api.delete_post_from_favorites( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() + post = db.session.query(model.Post).one() assert post.score == 1 - assert db.session.query(db.PostFavorite).count() == 0 + assert db.session.query(model.PostFavorite).count() == 0 assert post.favorite_count == 0 @@ -68,8 +69,8 @@ def test_favoriting_twice( api.post_api.add_post_to_favorites( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 1 assert post.favorite_count == 1 @@ -92,8 +93,8 @@ def test_removing_twice( api.post_api.delete_post_from_favorites( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 0 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 0 assert post.favorite_count == 0 @@ -113,8 +114,8 @@ def test_favorites_from_multiple_users( api.post_api.add_post_to_favorites( context_factory(user=user2), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 2 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 2 assert post.favorite_count == 2 assert post.last_favorite_time == datetime(1997, 12, 2) @@ -133,5 +134,5 @@ def test_trying_to_rate_without_privileges( db.session.commit() with pytest.raises(errors.AuthError): api.post_api.add_post_to_favorites( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_post_featuring.py b/server/szurubooru/tests/api/test_post_featuring.py index a0a82c75..88e4e001 100644 --- a/server/szurubooru/tests/api/test_post_featuring.py +++ b/server/szurubooru/tests/api/test_post_featuring.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, snapshots @@ -8,14 +8,14 @@ from szurubooru.func import posts, snapshots def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:feature': db.User.RANK_REGULAR, - 'posts:view': db.User.RANK_REGULAR, + 'posts:feature': model.User.RANK_REGULAR, + 'posts:view': model.User.RANK_REGULAR, }, }) def test_featuring(user_factory, post_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory(id=1) db.session.add(post) db.session.flush() @@ -31,7 +31,7 @@ def test_featuring(user_factory, post_factory, context_factory): assert posts.get_post_by_id(1).is_featured result = api.post_api.get_featured_post( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == 'serialized post' snapshots.modify.assert_called_once_with(post, auth_user) @@ -40,7 +40,7 @@ def test_trying_to_omit_required_parameter(user_factory, context_factory): with pytest.raises(errors.MissingRequiredParameterError): api.post_api.set_featured_post( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_feature_the_same_post_twice( @@ -51,12 +51,12 @@ def test_trying_to_feature_the_same_post_twice( api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with pytest.raises(posts.PostAlreadyFeaturedError): api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_featuring_one_post_after_another( @@ -72,12 +72,12 @@ def test_featuring_one_post_after_another( api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with fake_datetime('1998'): api.post_api.set_featured_post( context_factory( params={'id': 2}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert posts.try_get_featured_post() is not None assert posts.try_get_featured_post().post_id == 2 assert not posts.get_post_by_id(1).is_featured @@ -89,7 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory): api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_feature_without_privileges(user_factory, context_factory): @@ -97,10 +97,10 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory): api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_getting_featured_post_without_privileges_to_view( user_factory, context_factory): api.post_api.get_featured_post( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS))) + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_post_merging.py b/server/szurubooru/tests/api/test_post_merging.py index e6540904..eb8464f8 100644 --- a/server/szurubooru/tests/api/test_post_merging.py +++ b/server/szurubooru/tests/api/test_post_merging.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'posts:merge': model.User.RANK_REGULAR}}) def test_merging(user_factory, context_factory, post_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) source_post = post_factory() target_post = post_factory() db.session.add_all([source_post, target_post]) @@ -25,6 +25,7 @@ def test_merging(user_factory, context_factory, post_factory): 'mergeToVersion': 1, 'remove': source_post.post_id, 'mergeTo': target_post.post_id, + 'replaceContent': False, }, user=auth_user)) posts.merge_posts.called_once_with(source_post, target_post) @@ -45,13 +46,14 @@ def test_trying_to_omit_mandatory_field( 'mergeToVersion': 1, 'remove': source_post.post_id, 'mergeTo': target_post.post_id, + 'replaceContent': False, } del params[field] with pytest.raises(errors.ValidationError): api.post_api.merge_posts( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_non_existing( @@ -63,12 +65,12 @@ def test_trying_to_merge_non_existing( api.post_api.merge_posts( context_factory( params={'remove': post.post_id, 'mergeTo': 999}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with pytest.raises(posts.PostNotFoundError): api.post_api.merge_posts( context_factory( params={'remove': 999, 'mergeTo': post.post_id}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_without_privileges( @@ -85,5 +87,6 @@ def test_trying_to_merge_without_privileges( 'mergeToVersion': 1, 'remove': source_post.post_id, 'mergeTo': target_post.post_id, + 'replaceContent': False, }, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_post_rating.py b/server/szurubooru/tests/api/test_post_rating.py index 18e823e7..0fca2f56 100644 --- a/server/szurubooru/tests/api/test_post_rating.py +++ b/server/szurubooru/tests/api/test_post_rating.py @@ -1,12 +1,12 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'posts:score': model.User.RANK_REGULAR}}) def test_simple_rating( @@ -22,8 +22,8 @@ def test_simple_rating( params={'score': 1}, user=user_factory()), {'post_id': post.post_id}) assert result == 'serialized post' - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 1 assert post is not None assert post.score == 1 @@ -43,8 +43,8 @@ def test_updating_rating( api.post_api.set_post_score( context_factory(params={'score': -1}, user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 1 assert post.score == -1 @@ -63,8 +63,8 @@ def test_updating_rating_to_zero( api.post_api.set_post_score( context_factory(params={'score': 0}, user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 0 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 0 assert post.score == 0 @@ -83,8 +83,8 @@ def test_deleting_rating( api.post_api.delete_post_score( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 0 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 0 assert post.score == 0 @@ -104,8 +104,8 @@ def test_ratings_from_multiple_users( api.post_api.set_post_score( context_factory(params={'score': -1}, user=user2), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 2 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 2 assert post.score == 0 @@ -136,5 +136,5 @@ def test_trying_to_rate_without_privileges( api.post_api.set_post_score( context_factory( params={'score': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index a02c7bc1..9d9db72a 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts @@ -9,8 +9,8 @@ from szurubooru.func import posts def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:list': db.User.RANK_REGULAR, - 'posts:view': db.User.RANK_REGULAR, + 'posts:list': model.User.RANK_REGULAR, + 'posts:view': model.User.RANK_REGULAR, }, }) @@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): result = api.post_api.get_posts( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -36,10 +36,10 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): def test_using_special_tokens(user_factory, post_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post1 = post_factory(id=1) post2 = post_factory(id=2) - post1.favorited_by = [db.PostFavorite( + post1.favorited_by = [model.PostFavorite( user=auth_user, time=datetime.utcnow())] db.session.add_all([post1, post2, auth_user]) db.session.flush() @@ -68,7 +68,7 @@ def test_trying_to_use_special_tokens_without_logging_in( api.post_api.get_posts( context_factory( params={'query': 'special:fav', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_trying_to_retrieve_multiple_without_privileges( @@ -77,7 +77,7 @@ def test_trying_to_retrieve_multiple_without_privileges( api.post_api.get_posts( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, post_factory, context_factory): @@ -86,7 +86,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory): with patch('szurubooru.func.posts.serialize_post'): posts.serialize_post.return_value = 'serialized post' result = api.post_api.get_post( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 1}) assert result == 'serialized post' @@ -94,7 +94,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.get_post( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 999}) @@ -102,5 +102,5 @@ def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): api.post_api.get_post( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': 999}) diff --git a/server/szurubooru/tests/api/test_post_updating.py b/server/szurubooru/tests/api/test_post_updating.py index 790e835e..d3649307 100644 --- a/server/szurubooru/tests/api/test_post_updating.py +++ b/server/szurubooru/tests/api/test_post_updating.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, tags, snapshots, net @@ -9,22 +9,22 @@ from szurubooru.func import posts, tags, snapshots, net def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:edit:tags': db.User.RANK_REGULAR, - 'posts:edit:content': db.User.RANK_REGULAR, - 'posts:edit:safety': db.User.RANK_REGULAR, - 'posts:edit:source': db.User.RANK_REGULAR, - 'posts:edit:relations': db.User.RANK_REGULAR, - 'posts:edit:notes': db.User.RANK_REGULAR, - 'posts:edit:flags': db.User.RANK_REGULAR, - 'posts:edit:thumbnail': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_MODERATOR, + 'posts:edit:tags': model.User.RANK_REGULAR, + 'posts:edit:content': model.User.RANK_REGULAR, + 'posts:edit:safety': model.User.RANK_REGULAR, + 'posts:edit:source': model.User.RANK_REGULAR, + 'posts:edit:relations': model.User.RANK_REGULAR, + 'posts:edit:notes': model.User.RANK_REGULAR, + 'posts:edit:flags': model.User.RANK_REGULAR, + 'posts:edit:thumbnail': model.User.RANK_REGULAR, + 'tags:create': model.User.RANK_MODERATOR, }, }) def test_post_updating( context_factory, post_factory, user_factory, fake_datetime): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -76,7 +76,7 @@ def test_post_updating( posts.update_post_flags.assert_called_once_with( post, ['flag1', 'flag2']) posts.serialize_post.assert_called_once_with( - post, auth_user, options=None) + post, auth_user, options=[]) snapshots.modify.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() assert post.last_edit_time == datetime(1997, 1, 1) @@ -97,7 +97,7 @@ def test_uploading_from_url_saves_source( api.post_api.update_post( context_factory( params={'contentUrl': 'example.com', 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': post.post_id}) net.download.assert_called_once_with('example.com') posts.update_post_content.assert_called_once_with(post, b'content') @@ -122,7 +122,7 @@ def test_uploading_from_url_with_source_specified( 'contentUrl': 'example.com', 'source': 'example2.com', 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': post.post_id}) net.download.assert_called_once_with('example.com') posts.update_post_content.assert_called_once_with(post, b'content') @@ -134,7 +134,7 @@ def test_trying_to_update_non_existing(context_factory, user_factory): api.post_api.update_post( context_factory( params='whatever', - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 1}) @@ -158,7 +158,7 @@ def test_trying_to_update_field_without_privileges( context_factory( params={**params, **{'version': 1}}, files=files, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) @@ -173,5 +173,5 @@ def test_trying_to_create_tags_without_privileges( api.post_api.update_post( context_factory( params={'tags': ['tag1', 'tag2'], 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_snapshot_retrieving.py b/server/szurubooru/tests/api/test_snapshot_retrieving.py index 73b6f060..facbcd8a 100644 --- a/server/szurubooru/tests/api/test_snapshot_retrieving.py +++ b/server/szurubooru/tests/api/test_snapshot_retrieving.py @@ -1,10 +1,10 @@ from datetime import datetime import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors def snapshot_factory(): - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.creation_time = datetime(1999, 1, 1) snapshot.resource_type = 'dummy' snapshot.resource_pkey = 1 @@ -17,7 +17,7 @@ def snapshot_factory(): @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ - 'privileges': {'snapshots:list': db.User.RANK_REGULAR}, + 'privileges': {'snapshots:list': model.User.RANK_REGULAR}, }) @@ -29,7 +29,7 @@ def test_retrieving_multiple(user_factory, context_factory): result = api.snapshot_api.get_snapshots( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result['query'] == '' assert result['page'] == 1 assert result['pageSize'] == 100 @@ -43,4 +43,4 @@ def test_trying_to_retrieve_multiple_without_privileges( api.snapshot_api.get_snapshots( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_category_creating.py b/server/szurubooru/tests/api/test_tag_category_creating.py index 96afc390..fbd8b1bc 100644 --- a/server/szurubooru/tests/api/test_tag_category_creating.py +++ b/server/szurubooru/tests/api/test_tag_category_creating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories, tags, snapshots @@ -11,13 +11,13 @@ def _update_category_name(category, name): @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ - 'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, + 'privileges': {'tag_categories:create': model.User.RANK_REGULAR}, }) def test_creating_category( tag_category_factory, user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) category = tag_category_factory(name='meta') db.session.add(category) @@ -49,7 +49,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): api.tag_category_api.create_tag_category( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_create_without_privileges(user_factory, context_factory): @@ -57,4 +57,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory): api.tag_category_api.create_tag_category( context_factory( params={'name': 'meta', 'color': 'black'}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_category_deleting.py b/server/szurubooru/tests/api/test_tag_category_deleting.py index 1f1cde4c..1fc86431 100644 --- a/server/szurubooru/tests/api/test_tag_category_deleting.py +++ b/server/szurubooru/tests/api/test_tag_category_deleting.py @@ -1,18 +1,18 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories, tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ - 'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, + 'privileges': {'tag_categories:delete': model.User.RANK_REGULAR}, }) def test_deleting(user_factory, tag_category_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) category = tag_category_factory(name='category') db.session.add(tag_category_factory(name='root')) db.session.add(category) @@ -23,8 +23,8 @@ def test_deleting(user_factory, tag_category_factory, context_factory): context_factory(params={'version': 1}, user=auth_user), {'category_name': 'category'}) assert result == {} - assert db.session.query(db.TagCategory).count() == 1 - assert db.session.query(db.TagCategory).one().name == 'root' + assert db.session.query(model.TagCategory).count() == 1 + assert db.session.query(model.TagCategory).one().name == 'root' snapshots.delete.assert_called_once_with(category, auth_user) tags.export_to_json.assert_called_once_with() @@ -41,9 +41,9 @@ def test_trying_to_delete_used( api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'category'}) - assert db.session.query(db.TagCategory).count() == 1 + assert db.session.query(model.TagCategory).count() == 1 def test_trying_to_delete_last( @@ -54,14 +54,14 @@ def test_trying_to_delete_last( api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'root'}) def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.delete_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'bad'}) @@ -73,6 +73,6 @@ def test_trying_to_delete_without_privileges( api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'category_name': 'category'}) - assert db.session.query(db.TagCategory).count() == 1 + assert db.session.query(model.TagCategory).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_category_retrieving.py b/server/szurubooru/tests/api/test_tag_category_retrieving.py index 4f6610b3..0b98d743 100644 --- a/server/szurubooru/tests/api/test_tag_category_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_category_retrieving.py @@ -1,5 +1,5 @@ import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories @@ -7,8 +7,8 @@ from szurubooru.func import tag_categories def inject_config(config_injector): config_injector({ 'privileges': { - 'tag_categories:list': db.User.RANK_REGULAR, - 'tag_categories:view': db.User.RANK_REGULAR, + 'tag_categories:list': model.User.RANK_REGULAR, + 'tag_categories:view': model.User.RANK_REGULAR, }, }) @@ -21,7 +21,7 @@ def test_retrieving_multiple( ]) db.session.flush() result = api.tag_category_api.get_tag_categories( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) + context_factory(user=user_factory(rank=model.User.RANK_REGULAR))) assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] @@ -30,7 +30,7 @@ def test_retrieving_single( db.session.add(tag_category_factory(name='cat')) db.session.flush() result = api.tag_category_api.get_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'cat'}) assert result == { 'name': 'cat', @@ -44,7 +44,7 @@ def test_retrieving_single( def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.get_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': '-'}) @@ -52,5 +52,5 @@ def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_category_api.get_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'category_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index 9dd0f6bb..d406dd1f 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories, tags, snapshots @@ -12,15 +12,15 @@ def _update_category_name(category, name): def inject_config(config_injector): config_injector({ 'privileges': { - 'tag_categories:edit:name': db.User.RANK_REGULAR, - 'tag_categories:edit:color': db.User.RANK_REGULAR, - 'tag_categories:set_default': db.User.RANK_REGULAR, + 'tag_categories:edit:name': model.User.RANK_REGULAR, + 'tag_categories:edit:color': model.User.RANK_REGULAR, + 'tag_categories:set_default': model.User.RANK_REGULAR, }, }) def test_simple_updating(user_factory, tag_category_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) category = tag_category_factory(name='name', color='black') db.session.add(category) db.session.flush() @@ -61,7 +61,7 @@ def test_omitting_optional_field( api.tag_category_api.update_tag_category( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'name'}) @@ -70,7 +70,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.tag_category_api.update_tag_category( context_factory( params={'name': ['dummy']}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'bad'}) @@ -86,7 +86,7 @@ def test_trying_to_update_without_privileges( api.tag_category_api.update_tag_category( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'category_name': 'dummy'}) @@ -106,7 +106,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory): 'color': 'white', 'version': 1, }, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'name'}) assert result == 'serialized category' tag_categories.set_default_category.assert_called_once_with(category) diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index dc056280..771b9f61 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:create': model.User.RANK_REGULAR}}) def test_creating_simple_tags(tag_factory, user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) tag = tag_factory() with patch('szurubooru.func.tags.create_tag'), \ patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ @@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): api.tag_api.create_tag( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) @pytest.mark.parametrize('field', ['implications', 'suggestions']) @@ -70,7 +70,7 @@ def test_omitting_optional_field( api.tag_api.create_tag( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_create_tag_without_privileges( @@ -84,4 +84,4 @@ def test_trying_to_create_tag_without_privileges( 'suggestions': ['tag'], 'implications': [], }, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index a657b02e..fbd35e12 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:delete': model.User.RANK_REGULAR}}) def test_deleting(user_factory, tag_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) tag = tag_factory(names=['tag']) db.session.add(tag) db.session.commit() @@ -20,7 +20,7 @@ def test_deleting(user_factory, tag_factory, context_factory): context_factory(params={'version': 1}, user=auth_user), {'tag_name': 'tag'}) assert result == {} - assert db.session.query(db.Tag).count() == 0 + assert db.session.query(model.Tag).count() == 0 snapshots.delete.assert_called_once_with(tag, auth_user) tags.export_to_json.assert_called_once_with() @@ -36,17 +36,17 @@ def test_deleting_used( api.tag_api.delete_tag( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) db.session.refresh(post) - assert db.session.query(db.Tag).count() == 0 + assert db.session.query(model.Tag).count() == 0 assert post.tags == [] def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.delete_tag( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'bad'}) @@ -58,6 +58,6 @@ def test_trying_to_delete_without_privileges( api.tag_api.delete_tag( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': 'tag'}) - assert db.session.query(db.Tag).count() == 1 + assert db.session.query(model.Tag).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index a448c9c4..484fbfa6 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:merge': model.User.RANK_REGULAR}}) def test_merging(user_factory, tag_factory, context_factory, post_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) db.session.add_all([source_tag, target_tag]) @@ -62,7 +62,7 @@ def test_trying_to_omit_mandatory_field( api.tag_api.merge_tags( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_non_existing( @@ -73,12 +73,12 @@ def test_trying_to_merge_non_existing( api.tag_api.merge_tags( context_factory( params={'remove': 'good', 'mergeTo': 'bad'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with pytest.raises(tags.TagNotFoundError): api.tag_api.merge_tags( context_factory( params={'remove': 'bad', 'mergeTo': 'good'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_without_privileges( @@ -97,4 +97,4 @@ def test_trying_to_merge_without_privileges( 'remove': 'source', 'mergeTo': 'target', }, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_retrieving.py b/server/szurubooru/tests/api/test_tag_retrieving.py index 86837f97..fd2b2cb5 100644 --- a/server/szurubooru/tests/api/test_tag_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_retrieving.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags @@ -8,8 +8,8 @@ from szurubooru.func import tags def inject_config(config_injector): config_injector({ 'privileges': { - 'tags:list': db.User.RANK_REGULAR, - 'tags:view': db.User.RANK_REGULAR, + 'tags:list': model.User.RANK_REGULAR, + 'tags:view': model.User.RANK_REGULAR, }, }) @@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory): result = api.tag_api.get_tags( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges( api.tag_api.get_tags( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, tag_factory, context_factory): @@ -50,7 +50,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory): tags.serialize_tag.return_value = 'serialized tag' result = api.tag_api.get_tag( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) assert result == 'serialized tag' @@ -59,7 +59,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.get_tag( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': '-'}) @@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges( with pytest.raises(errors.AuthError): api.tag_api.get_tag( context_factory( - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py index 6de25fcc..fc2f5aaa 100644 --- a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py @@ -1,12 +1,12 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:view': model.User.RANK_REGULAR}}) def test_get_tag_siblings(user_factory, tag_factory, context_factory): @@ -21,7 +21,7 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): (tag_factory(names=['sib2']), 3), ] result = api.tag_api.get_tag_siblings( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) assert result == { 'results': [ @@ -40,12 +40,12 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): def test_trying_to_retrieve_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.get_tag_siblings( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': '-'}) def test_trying_to_retrieve_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_api.get_tag_siblings( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 3fe69bd8..fb63e353 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @@ -8,18 +8,18 @@ from szurubooru.func import tags, snapshots def inject_config(config_injector): config_injector({ 'privileges': { - 'tags:create': db.User.RANK_REGULAR, - 'tags:edit:names': db.User.RANK_REGULAR, - 'tags:edit:category': db.User.RANK_REGULAR, - 'tags:edit:description': db.User.RANK_REGULAR, - 'tags:edit:suggestions': db.User.RANK_REGULAR, - 'tags:edit:implications': db.User.RANK_REGULAR, + 'tags:create': model.User.RANK_REGULAR, + 'tags:edit:names': model.User.RANK_REGULAR, + 'tags:edit:category': model.User.RANK_REGULAR, + 'tags:edit:description': model.User.RANK_REGULAR, + 'tags:edit:suggestions': model.User.RANK_REGULAR, + 'tags:edit:implications': model.User.RANK_REGULAR, }, }) def test_simple_updating(user_factory, tag_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) tag = tag_factory(names=['tag1', 'tag2']) db.session.add(tag) db.session.commit() @@ -56,8 +56,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory): tag, ['sug1', 'sug2']) tags.update_tag_implications.assert_called_once_with( tag, ['imp1', 'imp2']) - tags.serialize_tag.assert_called_once_with( - tag, options=None) + tags.serialize_tag.assert_called_once_with(tag, options=[]) snapshots.modify.assert_called_once_with(tag, auth_user) tags.export_to_json.assert_called_once_with() @@ -90,7 +89,7 @@ def test_omitting_optional_field( api.tag_api.update_tag( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) @@ -99,7 +98,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.tag_api.update_tag( context_factory( params={'names': ['dummy']}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag1'}) @@ -117,7 +116,7 @@ def test_trying_to_update_without_privileges( api.tag_api.update_tag( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': 'tag'}) @@ -127,9 +126,9 @@ def test_trying_to_create_tags_without_privileges( db.session.add(tag) db.session.commit() config_injector({'privileges': { - 'tags:create': db.User.RANK_ADMINISTRATOR, - 'tags:edit:suggestions': db.User.RANK_REGULAR, - 'tags:edit:implications': db.User.RANK_REGULAR, + 'tags:create': model.User.RANK_ADMINISTRATOR, + 'tags:edit:suggestions': model.User.RANK_REGULAR, + 'tags:edit:implications': model.User.RANK_REGULAR, }}) with patch('szurubooru.func.tags.get_or_create_tags_by_names'): tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) @@ -137,12 +136,12 @@ def test_trying_to_create_tags_without_privileges( api.tag_api.update_tag( context_factory( params={'suggestions': ['tag1', 'tag2'], 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) db.session.rollback() with pytest.raises(errors.AuthError): api.tag_api.update_tag( context_factory( params={'implications': ['tag1', 'tag2'], 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index 8b583b6e..df2e80bb 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -31,7 +31,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime): 'avatarStyle': 'manual', }, files={'avatar': b'...'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == 'serialized user' users.create_user.assert_called_once_with( 'chewie1', 'oks', 'asd@asd.asd') @@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): 'password': 'oks', } user = user_factory() - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) del params[field] with patch('szurubooru.func.users.create_user'), \ pytest.raises(errors.MissingRequiredParameterError): @@ -70,7 +70,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): } del params[field] user = user_factory() - auth_user = user_factory(rank=db.User.RANK_MODERATOR) + auth_user = user_factory(rank=model.User.RANK_MODERATOR) with patch('szurubooru.func.users.create_user'), \ patch('szurubooru.func.users.update_user_avatar'), \ patch('szurubooru.func.users.serialize_user'): @@ -84,4 +84,4 @@ def test_trying_to_create_user_without_privileges( with pytest.raises(errors.AuthError): api.user_api.create_user(context_factory( params='whatever', - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_user_deleting.py b/server/szurubooru/tests/api/test_user_deleting.py index 9dd87764..2bd53e2b 100644 --- a/server/szurubooru/tests/api/test_user_deleting.py +++ b/server/szurubooru/tests/api/test_user_deleting.py @@ -1,5 +1,5 @@ import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -7,45 +7,45 @@ from szurubooru.func import users def inject_config(config_injector): config_injector({ 'privileges': { - 'users:delete:self': db.User.RANK_REGULAR, - 'users:delete:any': db.User.RANK_MODERATOR, + 'users:delete:self': model.User.RANK_REGULAR, + 'users:delete:any': model.User.RANK_MODERATOR, }, }) def test_deleting_oneself(user_factory, context_factory): - user = user_factory(name='u', rank=db.User.RANK_REGULAR) + user = user_factory(name='u', rank=model.User.RANK_REGULAR) db.session.add(user) db.session.commit() result = api.user_api.delete_user( context_factory( params={'version': 1}, user=user), {'user_name': 'u'}) assert result == {} - assert db.session.query(db.User).count() == 0 + assert db.session.query(model.User).count() == 0 def test_deleting_someone_else(user_factory, context_factory): - user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) + user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR) db.session.add_all([user1, user2]) db.session.commit() api.user_api.delete_user( context_factory( params={'version': 1}, user=user2), {'user_name': 'u1'}) - assert db.session.query(db.User).count() == 1 + assert db.session.query(model.User).count() == 1 def test_trying_to_delete_someone_else_without_privileges( user_factory, context_factory): - user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) + user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR) db.session.add_all([user1, user2]) db.session.commit() with pytest.raises(errors.AuthError): api.user_api.delete_user( context_factory( params={'version': 1}, user=user2), {'user_name': 'u1'}) - assert db.session.query(db.User).count() == 2 + assert db.session.query(model.User).count() == 2 def test_trying_to_delete_non_existing(user_factory, context_factory): @@ -53,5 +53,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory): api.user_api.delete_user( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'user_name': 'bad'}) diff --git a/server/szurubooru/tests/api/test_user_retrieving.py b/server/szurubooru/tests/api/test_user_retrieving.py index 6400e0d4..9be26200 100644 --- a/server/szurubooru/tests/api/test_user_retrieving.py +++ b/server/szurubooru/tests/api/test_user_retrieving.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -8,16 +8,16 @@ from szurubooru.func import users def inject_config(config_injector): config_injector({ 'privileges': { - 'users:list': db.User.RANK_REGULAR, - 'users:view': db.User.RANK_REGULAR, - 'users:edit:any:email': db.User.RANK_MODERATOR, + 'users:list': model.User.RANK_REGULAR, + 'users:view': model.User.RANK_REGULAR, + 'users:edit:any:email': model.User.RANK_MODERATOR, }, }) def test_retrieving_multiple(user_factory, context_factory): - user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR) - user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) + user1 = user_factory(name='u1', rank=model.User.RANK_MODERATOR) + user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR) db.session.add_all([user1, user2]) db.session.flush() with patch('szurubooru.func.users.serialize_user'): @@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, context_factory): result = api.user_api.get_users( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -41,12 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges( api.user_api.get_users( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, context_factory): - user = user_factory(name='u1', rank=db.User.RANK_REGULAR) - auth_user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(name='u1', rank=model.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) db.session.add(user) db.session.flush() with patch('szurubooru.func.users.serialize_user'): @@ -57,7 +57,7 @@ def test_retrieving_single(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) with pytest.raises(users.UserNotFoundError): api.user_api.get_user( context_factory(user=auth_user), {'user_name': '-'}) @@ -65,8 +65,8 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_ANONYMOUS) - db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR)) + auth_user = user_factory(rank=model.User.RANK_ANONYMOUS) + db.session.add(user_factory(name='u1', rank=model.User.RANK_REGULAR)) db.session.flush() with pytest.raises(errors.AuthError): api.user_api.get_user( diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index 921b2697..af750493 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -8,23 +8,23 @@ from szurubooru.func import users def inject_config(config_injector): config_injector({ 'privileges': { - 'users:edit:self:name': db.User.RANK_REGULAR, - 'users:edit:self:pass': db.User.RANK_REGULAR, - 'users:edit:self:email': db.User.RANK_REGULAR, - 'users:edit:self:rank': db.User.RANK_MODERATOR, - 'users:edit:self:avatar': db.User.RANK_MODERATOR, - 'users:edit:any:name': db.User.RANK_MODERATOR, - 'users:edit:any:pass': db.User.RANK_MODERATOR, - 'users:edit:any:email': db.User.RANK_MODERATOR, - 'users:edit:any:rank': db.User.RANK_ADMINISTRATOR, - 'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR, + 'users:edit:self:name': model.User.RANK_REGULAR, + 'users:edit:self:pass': model.User.RANK_REGULAR, + 'users:edit:self:email': model.User.RANK_REGULAR, + 'users:edit:self:rank': model.User.RANK_MODERATOR, + 'users:edit:self:avatar': model.User.RANK_MODERATOR, + 'users:edit:any:name': model.User.RANK_MODERATOR, + 'users:edit:any:pass': model.User.RANK_MODERATOR, + 'users:edit:any:email': model.User.RANK_MODERATOR, + 'users:edit:any:rank': model.User.RANK_ADMINISTRATOR, + 'users:edit:any:avatar': model.User.RANK_ADMINISTRATOR, }, }) def test_updating_user(context_factory, user_factory): - user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) - auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR) + user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) + auth_user = user_factory(rank=model.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() @@ -63,13 +63,13 @@ def test_updating_user(context_factory, user_factory): users.update_user_avatar.assert_called_once_with( user, 'manual', b'...') users.serialize_user.assert_called_once_with( - user, auth_user, options=None) + user, auth_user, options=[]) @pytest.mark.parametrize( 'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) def test_omitting_optional_field(user_factory, context_factory, field): - user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) + user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() params = { @@ -96,7 +96,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): def test_trying_to_update_non_existing(user_factory, context_factory): - user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) + user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() with pytest.raises(users.UserNotFoundError): @@ -113,8 +113,8 @@ def test_trying_to_update_non_existing(user_factory, context_factory): ]) def test_trying_to_update_field_without_privileges( user_factory, context_factory, params): - user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) + user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR) db.session.add_all([user1, user2]) db.session.flush() with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index db34ee02..e71f9609 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -7,8 +7,8 @@ from unittest.mock import patch from datetime import datetime import pytest import freezegun -import sqlalchemy -from szurubooru import config, db, rest +import sqlalchemy as sa +from szurubooru import config, db, model, rest class QueryCounter: @@ -36,10 +36,10 @@ if not config.config['test_database']: raise RuntimeError('Test database not configured.') _query_counter = QueryCounter() -_engine = sqlalchemy.create_engine(config.config['test_database']) -db.Base.metadata.drop_all(bind=_engine) -db.Base.metadata.create_all(bind=_engine) -sqlalchemy.event.listen( +_engine = sa.create_engine(config.config['test_database']) +model.Base.metadata.drop_all(bind=_engine) +model.Base.metadata.create_all(bind=_engine) +sa.event.listen( _engine, 'before_cursor_execute', _query_counter.create_before_cursor_execute()) @@ -79,14 +79,14 @@ def query_logger(): @pytest.yield_fixture(scope='function', autouse=True) def session(query_logger): # pylint: disable=unused-argument - db.sessionmaker = sqlalchemy.orm.sessionmaker( + db.sessionmaker = sa.orm.sessionmaker( bind=_engine, autoflush=False) - db.session = sqlalchemy.orm.scoped_session(db.sessionmaker) + db.session = sa.orm.scoped_session(db.sessionmaker) try: yield db.session finally: db.session.remove() - for table in reversed(db.Base.metadata.sorted_tables): + for table in reversed(model.Base.metadata.sorted_tables): db.session.execute(table.delete()) db.session.commit() @@ -101,7 +101,7 @@ def context_factory(session): params=params or {}, files=files or {}) ctx.session = session - ctx.user = user or db.User() + ctx.user = user or model.User() return ctx return factory @@ -115,15 +115,15 @@ def config_injector(): @pytest.fixture def user_factory(): - def factory(name=None, rank=db.User.RANK_REGULAR, email='dummy'): - user = db.User() + def factory(name=None, rank=model.User.RANK_REGULAR, email='dummy'): + user = model.User() user.name = name or get_unique_name() user.password_salt = 'dummy' user.password_hash = 'dummy' user.email = email user.rank = rank user.creation_time = datetime(1997, 1, 1) - user.avatar_style = db.User.AVATAR_GRAVATAR + user.avatar_style = model.User.AVATAR_GRAVATAR return user return factory @@ -131,7 +131,7 @@ def user_factory(): @pytest.fixture def tag_category_factory(): def factory(name=None, color='dummy', default=False): - category = db.TagCategory() + category = model.TagCategory() category.name = name or get_unique_name() category.color = color category.default = default @@ -143,12 +143,12 @@ def tag_category_factory(): def tag_factory(): def factory(names=None, category=None): if not category: - category = db.TagCategory(get_unique_name()) + category = model.TagCategory(get_unique_name()) db.session.add(category) - tag = db.Tag() + tag = model.Tag() tag.names = [] for i, name in enumerate(names or [get_unique_name()]): - tag.names.append(db.TagName(name, i)) + tag.names.append(model.TagName(name, i)) tag.category = category tag.creation_time = datetime(1996, 1, 1) return tag @@ -167,10 +167,10 @@ def post_factory(skip_post_hashing): # pylint: disable=invalid-name def factory( id=None, - safety=db.Post.SAFETY_SAFE, - type=db.Post.TYPE_IMAGE, + safety=model.Post.SAFETY_SAFE, + type=model.Post.TYPE_IMAGE, checksum='...'): - post = db.Post() + post = model.Post() post.post_id = id post.safety = safety post.type = type @@ -191,7 +191,7 @@ def comment_factory(user_factory, post_factory): if not post: post = post_factory() db.session.add(post) - comment = db.Comment() + comment = model.Comment() comment.user = user comment.post = post comment.text = text @@ -207,7 +207,7 @@ def post_score_factory(user_factory, post_factory): user = user_factory() if post is None: post = post_factory() - return db.PostScore( + return model.PostScore( post=post, user=user, score=score, time=datetime(1999, 1, 1)) return factory @@ -219,7 +219,7 @@ def post_favorite_factory(user_factory, post_factory): user = user_factory() if post is None: post = post_factory() - return db.PostFavorite( + return model.PostFavorite( post=post, user=user, time=datetime(1999, 1, 1)) return factory diff --git a/server/szurubooru/tests/func/test_comments.py b/server/szurubooru/tests/func/test_comments.py index c3c2fde1..f1e5d0f1 100644 --- a/server/szurubooru/tests/func/test_comments.py +++ b/server/szurubooru/tests/func/test_comments.py @@ -38,8 +38,6 @@ def test_try_get_comment(comment_factory): db.session.flush() assert comments.try_get_comment_by_id(comment.comment_id + 1) is None assert comments.try_get_comment_by_id(comment.comment_id) is comment - with pytest.raises(comments.InvalidCommentIdError): - comments.try_get_comment_by_id('-') def test_get_comment(comment_factory): @@ -49,8 +47,6 @@ def test_get_comment(comment_factory): with pytest.raises(comments.CommentNotFoundError): comments.get_comment_by_id(comment.comment_id + 1) assert comments.get_comment_by_id(comment.comment_id) is comment - with pytest.raises(comments.InvalidCommentIdError): - comments.get_comment_by_id('-') def test_create_comment(user_factory, post_factory, fake_datetime): diff --git a/server/szurubooru/tests/func/test_image_hash.py b/server/szurubooru/tests/func/test_image_hash.py index becba906..1b6efd21 100644 --- a/server/szurubooru/tests/func/test_image_hash.py +++ b/server/szurubooru/tests/func/test_image_hash.py @@ -2,7 +2,13 @@ from szurubooru.func import image_hash def test_hashing(read_asset, config_injector): - config_injector({'elasticsearch': {'index': 'szurubooru_test'}}) + config_injector({ + 'elasticsearch': { + 'host': 'localhost', + 'port': 9200, + 'index': 'szurubooru_test', + }, + }) image_hash.purge() image_hash.add_image('test', read_asset('jpeg.jpg')) diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 682a1ccc..76064699 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -2,7 +2,7 @@ import os from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import ( posts, users, comments, tags, images, files, util, image_hash) @@ -14,7 +14,7 @@ from szurubooru.func import ( ]) def test_get_post_url(input_mime_type, expected_url, config_injector): config_injector({'data_url': 'http://example.com/'}) - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_content_url(post) == expected_url @@ -23,7 +23,7 @@ def test_get_post_url(input_mime_type, expected_url, config_injector): @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_url(input_mime_type, config_injector): config_injector({'data_url': 'http://example.com/'}) - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_thumbnail_url(post) \ @@ -36,7 +36,7 @@ def test_get_post_thumbnail_url(input_mime_type, config_injector): ('totally/unknown', 'posts/1.dat'), ]) def test_get_post_content_path(input_mime_type, expected_path): - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_content_path(post) == expected_path @@ -44,7 +44,7 @@ def test_get_post_content_path(input_mime_type, expected_path): @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_path(input_mime_type): - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_thumbnail_path(post) == 'generated-thumbnails/1.jpg' @@ -52,7 +52,7 @@ def test_get_post_thumbnail_path(input_mime_type): @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_backup_path(input_mime_type): - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_thumbnail_backup_path(post) \ @@ -60,7 +60,7 @@ def test_get_post_thumbnail_backup_path(input_mime_type): def test_serialize_note(): - note = db.PostNote() + note = model.PostNote() note.polygon = [[0, 1], [1, 1], [1, 0], [0, 0]] note.text = '...' assert posts.serialize_note(note) == { @@ -86,7 +86,7 @@ def test_serialize_post( = lambda comment, auth_user: comment.user.name auth_user = user_factory(name='auth user') - post = db.Post() + post = model.Post() post.post_id = 1 post.creation_time = datetime(1997, 1, 1) post.last_edit_time = datetime(1998, 1, 1) @@ -94,9 +94,9 @@ def test_serialize_post( tag_factory(names=['tag1', 'tag2']), tag_factory(names=['tag3']) ] - post.safety = db.Post.SAFETY_SAFE + post.safety = model.Post.SAFETY_SAFE post.source = '4gag' - post.type = db.Post.TYPE_IMAGE + post.type = model.Post.TYPE_IMAGE post.checksum = 'deadbeef' post.mime_type = 'image/jpeg' post.file_size = 100 @@ -116,25 +116,25 @@ def test_serialize_post( user=user_factory(name='commenter2'), post=post, time=datetime(1999, 1, 2)), - db.PostFavorite( + model.PostFavorite( post=post, user=user_factory(name='fav1'), time=datetime(1800, 1, 1)), - db.PostFeature( + model.PostFeature( post=post, user=user_factory(), time=datetime(1999, 1, 1)), - db.PostScore( + model.PostScore( post=post, user=auth_user, score=-1, time=datetime(1800, 1, 1)), - db.PostScore( + model.PostScore( post=post, user=user_factory(), score=1, time=datetime(1800, 1, 1)), - db.PostScore( + model.PostScore( post=post, user=user_factory(), score=1, @@ -209,8 +209,6 @@ def test_try_get_post_by_id(post_factory): db.session.flush() assert posts.try_get_post_by_id(post.post_id) == post assert posts.try_get_post_by_id(post.post_id + 1) is None - with pytest.raises(posts.InvalidPostIdError): - posts.get_post_by_id('-') def test_get_post_by_id(post_factory): @@ -220,8 +218,6 @@ def test_get_post_by_id(post_factory): assert posts.get_post_by_id(post.post_id) == post with pytest.raises(posts.PostNotFoundError): posts.get_post_by_id(post.post_id + 1) - with pytest.raises(posts.InvalidPostIdError): - posts.get_post_by_id('-') def test_create_post(user_factory, fake_datetime): @@ -237,30 +233,30 @@ def test_create_post(user_factory, fake_datetime): @pytest.mark.parametrize('input_safety,expected_safety', [ - ('safe', db.Post.SAFETY_SAFE), - ('sketchy', db.Post.SAFETY_SKETCHY), - ('unsafe', db.Post.SAFETY_UNSAFE), + ('safe', model.Post.SAFETY_SAFE), + ('sketchy', model.Post.SAFETY_SKETCHY), + ('unsafe', model.Post.SAFETY_UNSAFE), ]) def test_update_post_safety(input_safety, expected_safety): - post = db.Post() + post = model.Post() posts.update_post_safety(post, input_safety) assert post.safety == expected_safety def test_update_post_safety_with_invalid_string(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostSafetyError): posts.update_post_safety(post, 'bad') def test_update_post_source(): - post = db.Post() + post = model.Post() posts.update_post_source(post, 'x') assert post.source == 'x' def test_update_post_source_with_too_long_string(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostSourceError): posts.update_post_source(post, 'x' * 1000) @@ -268,24 +264,24 @@ def test_update_post_source_with_too_long_string(): @pytest.mark.parametrize( 'is_existing,input_file,expected_mime_type,expected_type,output_file_name', [ - (True, 'png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), - (False, 'png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), - (False, 'jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), - (False, 'gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), + (True, 'png.png', 'image/png', model.Post.TYPE_IMAGE, '1.png'), + (False, 'png.png', 'image/png', model.Post.TYPE_IMAGE, '1.png'), + (False, 'jpeg.jpg', 'image/jpeg', model.Post.TYPE_IMAGE, '1.jpg'), + (False, 'gif.gif', 'image/gif', model.Post.TYPE_IMAGE, '1.gif'), ( False, 'gif-animated.gif', 'image/gif', - db.Post.TYPE_ANIMATION, + model.Post.TYPE_ANIMATION, '1.gif', ), - (False, 'webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), - (False, 'mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), + (False, 'webm.webm', 'video/webm', model.Post.TYPE_VIDEO, '1.webm'), + (False, 'mp4.mp4', 'video/mp4', model.Post.TYPE_VIDEO, '1.mp4'), ( False, 'flash.swf', 'application/x-shockwave-flash', - db.Post.TYPE_FLASH, + model.Post.TYPE_FLASH, '1.swf' ), ]) @@ -318,7 +314,7 @@ def test_update_post_content_for_new_post( assert post.type == expected_type assert post.checksum == 'crc' assert os.path.exists(output_file_path) - if post.type in (db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + if post.type in (model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): image_hash.delete_image.assert_called_once_with(post.post_id) image_hash.add_image.assert_called_once_with(post.post_id, content) else: @@ -368,7 +364,7 @@ def test_update_post_content_with_broken_content( @pytest.mark.parametrize('input_content', [None, b'not a media file']) def test_update_post_content_with_invalid_content(input_content): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostContentError): posts.update_post_content(post, input_content) @@ -492,7 +488,7 @@ def test_update_post_content_leaving_custom_thumbnail( def test_update_post_tags(tag_factory): - post = db.Post() + post = model.Post() with patch('szurubooru.func.tags.get_or_create_tags_by_names'): tags.get_or_create_tags_by_names.side_effect = lambda tag_names: \ ([tag_factory(names=[name]) for name in tag_names], []) @@ -528,7 +524,7 @@ def test_update_post_relations_bidirectionality(post_factory): def test_update_post_relations_with_nonexisting_posts(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostRelationError): posts.update_post_relations(post, [100]) @@ -542,7 +538,7 @@ def test_update_post_relations_with_itself(post_factory): def test_update_post_notes(): - post = db.Post() + post = model.Post() posts.update_post_notes( post, [ @@ -576,19 +572,19 @@ def test_update_post_notes(): [{'polygon': [[0, 0], [0, 0], [0, 1]]}], ]) def test_update_post_notes_with_invalid_content(input): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostNoteError): posts.update_post_notes(post, input) def test_update_post_flags(): - post = db.Post() + post = model.Post() posts.update_post_flags(post, ['loop']) assert post.flags == ['loop'] def test_update_post_flags_with_invalid_content(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostFlagError): posts.update_post_flags(post, ['invalid']) diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index d4c6754a..09491990 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -1,7 +1,7 @@ from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import snapshots, users @@ -56,20 +56,20 @@ def test_get_post_snapshot(post_factory, user_factory, tag_factory): db.session.add_all([user, tag1, tag2, post, related_post1, related_post2]) db.session.flush() - score = db.PostScore() + score = model.PostScore() score.post = post score.user = user score.time = datetime(1997, 1, 1) score.score = 1 - favorite = db.PostFavorite() + favorite = model.PostFavorite() favorite.post = post favorite.user = user favorite.time = datetime(1997, 1, 1) - feature = db.PostFeature() + feature = model.PostFeature() feature.post = post feature.user = user feature.time = datetime(1997, 1, 1) - note = db.PostNote() + note = model.PostNote() note.post = post note.polygon = [(1, 1), (200, 1), (200, 200), (1, 200)] note.text = 'some text' @@ -105,7 +105,7 @@ def test_get_post_snapshot(post_factory, user_factory, tag_factory): def test_serialize_snapshot(user_factory): auth_user = user_factory() - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.operation = snapshot.OPERATION_CREATED snapshot.resource_type = 'type' snapshot.resource_name = 'id' @@ -132,9 +132,9 @@ def test_create(tag_factory, user_factory): snapshots.get_tag_snapshot.return_value = 'mocked' snapshots.create(tag, user_factory()) db.session.flush() - results = db.session.query(db.Snapshot).all() + results = db.session.query(model.Snapshot).all() assert len(results) == 1 - assert results[0].operation == db.Snapshot.OPERATION_CREATED + assert results[0].operation == model.Snapshot.OPERATION_CREATED assert results[0].data == 'mocked' @@ -144,16 +144,16 @@ def test_modify_saves_non_empty_diffs(post_factory, user_factory): 'SQLite doesn\'t support transaction isolation, ' 'which is required to retrieve original entity') post = post_factory() - post.notes = [db.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='old')] + post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='old')] user = user_factory() db.session.add_all([post, user]) db.session.commit() post.source = 'new source' - post.notes = [db.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='new')] + post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='new')] db.session.flush() snapshots.modify(post, user) db.session.flush() - results = db.session.query(db.Snapshot).all() + results = db.session.query(model.Snapshot).all() assert len(results) == 1 assert results[0].data == { 'type': 'object change', @@ -181,7 +181,7 @@ def test_modify_doesnt_save_empty_diffs(tag_factory, user_factory): db.session.commit() snapshots.modify(tag, user) db.session.flush() - assert db.session.query(db.Snapshot).count() == 0 + assert db.session.query(model.Snapshot).count() == 0 def test_delete(tag_factory, user_factory): @@ -192,9 +192,9 @@ def test_delete(tag_factory, user_factory): snapshots.get_tag_snapshot.return_value = 'mocked' snapshots.delete(tag, user_factory()) db.session.flush() - results = db.session.query(db.Snapshot).all() + results = db.session.query(model.Snapshot).all() assert len(results) == 1 - assert results[0].operation == db.Snapshot.OPERATION_DELETED + assert results[0].operation == model.Snapshot.OPERATION_DELETED assert results[0].data == 'mocked' @@ -205,6 +205,6 @@ def test_merge(tag_factory, user_factory): db.session.flush() snapshots.merge(source_tag, target_tag, user_factory()) db.session.flush() - result = db.session.query(db.Snapshot).one() - assert result.operation == db.Snapshot.OPERATION_MERGED + result = db.session.query(model.Snapshot).one() + assert result.operation == model.Snapshot.OPERATION_MERGED assert result.data == ['tag', 'target'] diff --git a/server/szurubooru/tests/func/test_tag_categories.py b/server/szurubooru/tests/func/test_tag_categories.py index cf74c2a5..d1e55709 100644 --- a/server/szurubooru/tests/func/test_tag_categories.py +++ b/server/szurubooru/tests/func/test_tag_categories.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import tag_categories, cache @@ -191,7 +191,7 @@ def test_get_default_category_name(tag_category_factory): db.session.flush() cache.purge() assert tag_categories.get_default_category_name() == category1.name - db.session.query(db.TagCategory).delete() + db.session.query(model.TagCategory).delete() cache.purge() with pytest.raises(tag_categories.TagCategoryNotFoundError): tag_categories.get_default_category_name() diff --git a/server/szurubooru/tests/func/test_tags.py b/server/szurubooru/tests/func/test_tags.py index d4674998..712c8e38 100644 --- a/server/szurubooru/tests/func/test_tags.py +++ b/server/szurubooru/tests/func/test_tags.py @@ -3,7 +3,7 @@ import json from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import tags, tag_categories, cache @@ -304,10 +304,10 @@ def test_delete(tag_factory): tag.implications = [tag_factory(names=['imp'])] db.session.add(tag) db.session.flush() - assert db.session.query(db.Tag).count() == 3 + assert db.session.query(model.Tag).count() == 3 tags.delete(tag) db.session.flush() - assert db.session.query(db.Tag).count() == 2 + assert db.session.query(model.Tag).count() == 2 def test_merge_tags_deletes_source_tag(tag_factory): diff --git a/server/szurubooru/tests/func/test_users.py b/server/szurubooru/tests/func/test_users.py index 73150bb2..53d47de6 100644 --- a/server/szurubooru/tests/func/test_users.py +++ b/server/szurubooru/tests/func/test_users.py @@ -1,7 +1,7 @@ from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db, errors +from szurubooru import db, model, errors from szurubooru.func import auth, users, files, util @@ -20,28 +20,28 @@ def test_get_avatar_path(user_name): ( 'user', None, - db.User.AVATAR_GRAVATAR, + model.User.AVATAR_GRAVATAR, ('https://gravatar.com/avatar/' + 'ee11cbb19052e40b07aac0ca060c23ee?d=retro&s=100'), ), ( None, 'user@example.com', - db.User.AVATAR_GRAVATAR, + model.User.AVATAR_GRAVATAR, ('https://gravatar.com/avatar/' + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100'), ), ( 'user', 'user@example.com', - db.User.AVATAR_GRAVATAR, + model.User.AVATAR_GRAVATAR, ('https://gravatar.com/avatar/' + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100'), ), ( 'user', None, - db.User.AVATAR_MANUAL, + model.User.AVATAR_MANUAL, 'http://example.com/avatars/user.png', ), ]) @@ -51,7 +51,7 @@ def test_get_avatar_url( 'data_url': 'http://example.com/', 'thumbnails': {'avatar_width': 100}, }) - user = db.User() + user = model.User() user.name = user_name user.email = user_email user.avatar_style = avatar_style @@ -100,7 +100,7 @@ def test_get_liked_post_count( user = user_factory() post = post_factory() auth_user = user if same_user else user_factory() - score = db.PostScore( + score = model.PostScore( post=post, user=user, score=score, time=datetime.now()) db.session.add_all([post, user, score]) db.session.flush() @@ -127,8 +127,8 @@ def test_serialize_user(user_factory): user = user_factory(name='dummy user') user.creation_time = datetime(1997, 1, 1) user.last_edit_time = datetime(1998, 1, 1) - user.avatar_style = db.User.AVATAR_MANUAL - user.rank = db.User.RANK_ADMINISTRATOR + user.avatar_style = model.User.AVATAR_MANUAL + user.rank = model.User.RANK_ADMINISTRATOR db.session.add(user) db.session.flush() assert users.serialize_user(user, auth_user) == { @@ -222,7 +222,7 @@ def test_create_user_for_first_user(fake_datetime): user = users.create_user('name', 'password', 'email') assert user.creation_time == datetime(1997, 1, 1) assert user.last_login_time is None - assert user.rank == db.User.RANK_ADMINISTRATOR + assert user.rank == model.User.RANK_ADMINISTRATOR users.update_user_name.assert_called_once_with(user, 'name') users.update_user_password.assert_called_once_with(user, 'password') users.update_user_email.assert_called_once_with(user, 'email') @@ -236,7 +236,7 @@ def test_create_user_for_subsequent_users(user_factory, config_injector): patch('szurubooru.func.users.update_user_email'), \ patch('szurubooru.func.users.update_user_password'): user = users.create_user('name', 'password', 'email') - assert user.rank == db.User.RANK_REGULAR + assert user.rank == model.User.RANK_REGULAR def test_update_user_name_with_empty_string(user_factory): @@ -379,7 +379,7 @@ def test_update_user_rank_with_higher_rank_than_possible(user_factory): db.session.flush() user = user_factory() auth_user = user_factory() - auth_user.rank = db.User.RANK_ANONYMOUS + auth_user.rank = model.User.RANK_ANONYMOUS with pytest.raises(errors.AuthError): users.update_user_rank(user, 'regular', auth_user) with pytest.raises(errors.AuthError): @@ -391,11 +391,11 @@ def test_update_user_rank(user_factory): db.session.flush() user = user_factory() auth_user = user_factory() - auth_user.rank = db.User.RANK_ADMINISTRATOR + auth_user.rank = model.User.RANK_ADMINISTRATOR users.update_user_rank(user, 'regular', auth_user) users.update_user_rank(auth_user, 'regular', auth_user) - assert user.rank == db.User.RANK_REGULAR - assert auth_user.rank == db.User.RANK_REGULAR + assert user.rank == model.User.RANK_REGULAR + assert auth_user.rank == model.User.RANK_REGULAR def test_update_user_avatar_with_invalid_style(user_factory): @@ -407,7 +407,7 @@ def test_update_user_avatar_with_invalid_style(user_factory): def test_update_user_avatar_to_gravatar(user_factory): user = user_factory() users.update_user_avatar(user, 'gravatar') - assert user.avatar_style == db.User.AVATAR_GRAVATAR + assert user.avatar_style == model.User.AVATAR_GRAVATAR def test_update_user_avatar_to_empty_manual(user_factory): @@ -431,7 +431,7 @@ def test_update_user_avatar_to_new_manual(user_factory, config_injector): user = user_factory() with patch('szurubooru.func.files.save'): users.update_user_avatar(user, 'manual', EMPTY_PIXEL) - assert user.avatar_style == db.User.AVATAR_MANUAL + assert user.avatar_style == model.User.AVATAR_MANUAL assert files.save.called diff --git a/server/szurubooru/tests/db/__init__.py b/server/szurubooru/tests/model/__init__.py similarity index 100% rename from server/szurubooru/tests/db/__init__.py rename to server/szurubooru/tests/model/__init__.py diff --git a/server/szurubooru/tests/db/test_comment.py b/server/szurubooru/tests/model/test_comment.py similarity index 74% rename from server/szurubooru/tests/db/test_comment.py rename to server/szurubooru/tests/model/test_comment.py index 9a78f952..ffd51893 100644 --- a/server/szurubooru/tests/db/test_comment.py +++ b/server/szurubooru/tests/model/test_comment.py @@ -1,11 +1,11 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_comment(user_factory, post_factory): user = user_factory() post = post_factory() - comment = db.Comment() + comment = model.Comment() comment.text = 'long text' * 1000 comment.user = user comment.post = post @@ -29,7 +29,7 @@ def test_cascade_deletions(comment_factory, user_factory, post_factory): db.session.add_all([user, comment]) db.session.flush() - score = db.CommentScore() + score = model.CommentScore() score.comment = comment score.user = user score.time = datetime(1997, 1, 1) @@ -39,14 +39,14 @@ def test_cascade_deletions(comment_factory, user_factory, post_factory): assert not db.session.dirty assert comment.user is not None and comment.user.user_id is not None - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Comment).count() == 1 - assert db.session.query(db.CommentScore).count() == 1 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Comment).count() == 1 + assert db.session.query(model.CommentScore).count() == 1 db.session.delete(comment) db.session.commit() assert not db.session.dirty - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Comment).count() == 0 - assert db.session.query(db.CommentScore).count() == 0 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Comment).count() == 0 + assert db.session.query(model.CommentScore).count() == 0 diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/model/test_post.py similarity index 71% rename from server/szurubooru/tests/db/test_post.py rename to server/szurubooru/tests/model/test_post.py index c0213535..f35e2751 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/model/test_post.py @@ -1,5 +1,5 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_post(post_factory, user_factory, tag_factory): @@ -8,7 +8,7 @@ def test_saving_post(post_factory, user_factory, tag_factory): tag2 = tag_factory() related_post1 = post_factory() related_post2 = post_factory() - post = db.Post() + post = model.Post() post.safety = 'safety' post.type = 'type' post.checksum = 'deadbeef' @@ -54,20 +54,20 @@ def test_cascade_deletions( user, tag1, tag2, post, related_post1, related_post2, comment]) db.session.flush() - score = db.PostScore() + score = model.PostScore() score.post = post score.user = user score.time = datetime(1997, 1, 1) score.score = 1 - favorite = db.PostFavorite() + favorite = model.PostFavorite() favorite.post = post favorite.user = user favorite.time = datetime(1997, 1, 1) - feature = db.PostFeature() + feature = model.PostFeature() feature.post = post feature.user = user feature.time = datetime(1997, 1, 1) - note = db.PostNote() + note = model.PostNote() note.post = post note.polygon = '' note.text = '' @@ -88,31 +88,31 @@ def test_cascade_deletions( assert not db.session.dirty assert post.user is not None and post.user.user_id is not None assert len(post.relations) == 1 - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Tag).count() == 2 - assert db.session.query(db.Post).count() == 3 - assert db.session.query(db.PostTag).count() == 2 - assert db.session.query(db.PostRelation).count() == 2 - assert db.session.query(db.PostScore).count() == 1 - assert db.session.query(db.PostNote).count() == 1 - assert db.session.query(db.PostFeature).count() == 1 - assert db.session.query(db.PostFavorite).count() == 1 - assert db.session.query(db.Comment).count() == 1 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Tag).count() == 2 + assert db.session.query(model.Post).count() == 3 + assert db.session.query(model.PostTag).count() == 2 + assert db.session.query(model.PostRelation).count() == 2 + assert db.session.query(model.PostScore).count() == 1 + assert db.session.query(model.PostNote).count() == 1 + assert db.session.query(model.PostFeature).count() == 1 + assert db.session.query(model.PostFavorite).count() == 1 + assert db.session.query(model.Comment).count() == 1 db.session.delete(post) db.session.commit() assert not db.session.dirty - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Tag).count() == 2 - assert db.session.query(db.Post).count() == 2 - assert db.session.query(db.PostTag).count() == 0 - assert db.session.query(db.PostRelation).count() == 0 - assert db.session.query(db.PostScore).count() == 0 - assert db.session.query(db.PostNote).count() == 0 - assert db.session.query(db.PostFeature).count() == 0 - assert db.session.query(db.PostFavorite).count() == 0 - assert db.session.query(db.Comment).count() == 0 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Tag).count() == 2 + assert db.session.query(model.Post).count() == 2 + assert db.session.query(model.PostTag).count() == 0 + assert db.session.query(model.PostRelation).count() == 0 + assert db.session.query(model.PostScore).count() == 0 + assert db.session.query(model.PostNote).count() == 0 + assert db.session.query(model.PostFeature).count() == 0 + assert db.session.query(model.PostFavorite).count() == 0 + assert db.session.query(model.Comment).count() == 0 def test_tracking_tag_count(post_factory, tag_factory): diff --git a/server/szurubooru/tests/db/test_tag.py b/server/szurubooru/tests/model/test_tag.py similarity index 80% rename from server/szurubooru/tests/db/test_tag.py rename to server/szurubooru/tests/model/test_tag.py index 02134d69..7d3d8d2f 100644 --- a/server/szurubooru/tests/db/test_tag.py +++ b/server/szurubooru/tests/model/test_tag.py @@ -1,5 +1,5 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_tag(tag_factory): @@ -7,11 +7,11 @@ def test_saving_tag(tag_factory): sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) imp2 = tag_factory(names=['imp2']) - tag = db.Tag() - tag.names = [db.TagName('alias1', 0), db.TagName('alias2', 1)] + tag = model.Tag() + tag.names = [model.TagName('alias1', 0), model.TagName('alias2', 1)] tag.suggestions = [] tag.implications = [] - tag.category = db.TagCategory('category') + tag.category = model.TagCategory('category') tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) db.session.add_all([tag, sug1, sug2, imp1, imp2]) @@ -29,9 +29,9 @@ def test_saving_tag(tag_factory): db.session.commit() tag = db.session \ - .query(db.Tag) \ - .join(db.TagName) \ - .filter(db.TagName.name == 'alias1') \ + .query(model.Tag) \ + .join(model.TagName) \ + .filter(model.TagName.name == 'alias1') \ .one() assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2'] assert tag.category.name == 'category' @@ -48,11 +48,11 @@ def test_cascade_deletions(tag_factory): sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) imp2 = tag_factory(names=['imp2']) - tag = db.Tag() - tag.names = [db.TagName('alias1', 0), db.TagName('alias2', 1)] + tag = model.Tag() + tag.names = [model.TagName('alias1', 0), model.TagName('alias2', 1)] tag.suggestions = [] tag.implications = [] - tag.category = db.TagCategory('category') + tag.category = model.TagCategory('category') tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) tag.post_count = 1 @@ -72,10 +72,10 @@ def test_cascade_deletions(tag_factory): db.session.delete(tag) db.session.commit() - assert db.session.query(db.Tag).count() == 4 - assert db.session.query(db.TagName).count() == 4 - assert db.session.query(db.TagImplication).count() == 0 - assert db.session.query(db.TagSuggestion).count() == 0 + assert db.session.query(model.Tag).count() == 4 + assert db.session.query(model.TagName).count() == 4 + assert db.session.query(model.TagImplication).count() == 0 + assert db.session.query(model.TagSuggestion).count() == 0 def test_tracking_post_count(post_factory, tag_factory): diff --git a/server/szurubooru/tests/db/test_user.py b/server/szurubooru/tests/model/test_user.py similarity index 66% rename from server/szurubooru/tests/db/test_user.py rename to server/szurubooru/tests/model/test_user.py index 59933e36..08875fa2 100644 --- a/server/szurubooru/tests/db/test_user.py +++ b/server/szurubooru/tests/model/test_user.py @@ -1,16 +1,16 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_user(): - user = db.User() + user = model.User() user.name = 'name' user.password_salt = 'salt' user.password_hash = 'hash' user.email = 'email' user.rank = 'rank' user.creation_time = datetime(1997, 1, 1) - user.avatar_style = db.User.AVATAR_GRAVATAR + user.avatar_style = model.User.AVATAR_GRAVATAR db.session.add(user) db.session.flush() db.session.refresh(user) @@ -21,7 +21,7 @@ def test_saving_user(): assert user.email == 'email' assert user.rank == 'rank' assert user.creation_time == datetime(1997, 1, 1) - assert user.avatar_style == db.User.AVATAR_GRAVATAR + assert user.avatar_style == model.User.AVATAR_GRAVATAR def test_upload_count(user_factory, post_factory): @@ -61,8 +61,8 @@ def test_favorite_count(user_factory, post_factory): post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostFavorite(post=post1, time=datetime.utcnow(), user=user1), - db.PostFavorite(post=post2, time=datetime.utcnow(), user=user2), + model.PostFavorite(post=post1, time=datetime.utcnow(), user=user1), + model.PostFavorite(post=post2, time=datetime.utcnow(), user=user2), ]) db.session.flush() db.session.refresh(user1) @@ -79,8 +79,10 @@ def test_liked_post_count(user_factory, post_factory): post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostScore(post=post1, time=datetime.utcnow(), user=user1, score=1), - db.PostScore(post=post2, time=datetime.utcnow(), user=user2, score=1), + model.PostScore( + post=post1, time=datetime.utcnow(), user=user1, score=1), + model.PostScore( + post=post2, time=datetime.utcnow(), user=user2, score=1), ]) db.session.flush() db.session.refresh(user1) @@ -98,8 +100,10 @@ def test_disliked_post_count(user_factory, post_factory): post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostScore(post=post1, time=datetime.utcnow(), user=user1, score=-1), - db.PostScore(post=post2, time=datetime.utcnow(), user=user2, score=1), + model.PostScore( + post=post1, time=datetime.utcnow(), user=user1, score=-1), + model.PostScore( + post=post2, time=datetime.utcnow(), user=user2, score=1), ]) db.session.flush() db.session.refresh(user1) @@ -114,34 +118,34 @@ def test_cascade_deletions(post_factory, user_factory, comment_factory): post = post_factory() post.user = user - post_score = db.PostScore() + post_score = model.PostScore() post_score.post = post post_score.user = user post_score.time = datetime(1997, 1, 1) post_score.score = 1 post.scores.append(post_score) - post_favorite = db.PostFavorite() + post_favorite = model.PostFavorite() post_favorite.post = post post_favorite.user = user post_favorite.time = datetime(1997, 1, 1) post.favorited_by.append(post_favorite) - post_feature = db.PostFeature() + post_feature = model.PostFeature() post_feature.post = post post_feature.user = user post_feature.time = datetime(1997, 1, 1) post.features.append(post_feature) comment = comment_factory(post=post, user=user) - comment_score = db.CommentScore() + comment_score = model.CommentScore() comment_score.comment = comment comment_score.user = user comment_score.time = datetime(1997, 1, 1) comment_score.score = 1 comment.scores.append(comment_score) - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.user = user snapshot.creation_time = datetime(1997, 1, 1) snapshot.resource_type = '-' @@ -154,27 +158,27 @@ def test_cascade_deletions(post_factory, user_factory, comment_factory): assert not db.session.dirty assert post.user is not None and post.user.user_id is not None - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Post).count() == 1 - assert db.session.query(db.PostScore).count() == 1 - assert db.session.query(db.PostFeature).count() == 1 - assert db.session.query(db.PostFavorite).count() == 1 - assert db.session.query(db.Comment).count() == 1 - assert db.session.query(db.CommentScore).count() == 1 - assert db.session.query(db.Snapshot).count() == 1 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Post).count() == 1 + assert db.session.query(model.PostScore).count() == 1 + assert db.session.query(model.PostFeature).count() == 1 + assert db.session.query(model.PostFavorite).count() == 1 + assert db.session.query(model.Comment).count() == 1 + assert db.session.query(model.CommentScore).count() == 1 + assert db.session.query(model.Snapshot).count() == 1 db.session.delete(user) db.session.commit() assert not db.session.dirty - assert db.session.query(db.User).count() == 0 - assert db.session.query(db.Post).count() == 1 - assert db.session.query(db.Post)[0].user is None - assert db.session.query(db.PostScore).count() == 0 - assert db.session.query(db.PostFeature).count() == 0 - assert db.session.query(db.PostFavorite).count() == 0 - assert db.session.query(db.Comment).count() == 1 - assert db.session.query(db.Comment)[0].user is None - assert db.session.query(db.CommentScore).count() == 0 - assert db.session.query(db.Snapshot).count() == 1 - assert db.session.query(db.Snapshot)[0].user is None + assert db.session.query(model.User).count() == 0 + assert db.session.query(model.Post).count() == 1 + assert db.session.query(model.Post)[0].user is None + assert db.session.query(model.PostScore).count() == 0 + assert db.session.query(model.PostFeature).count() == 0 + assert db.session.query(model.PostFavorite).count() == 0 + assert db.session.query(model.Comment).count() == 1 + assert db.session.query(model.Comment)[0].user is None + assert db.session.query(model.CommentScore).count() == 0 + assert db.session.query(model.Snapshot).count() == 1 + assert db.session.query(model.Snapshot)[0].user is None diff --git a/server/szurubooru/tests/rest/test_context.py b/server/szurubooru/tests/rest/test_context.py index 7380a855..e112ebbe 100644 --- a/server/szurubooru/tests/rest/test_context.py +++ b/server/szurubooru/tests/rest/test_context.py @@ -8,13 +8,14 @@ from szurubooru.func import net def test_has_param(): ctx = rest.Context(method=None, url=None, params={'key': 'value'}) assert ctx.has_param('key') - assert not ctx.has_param('key2') + assert not ctx.has_param('non-existing') def test_get_file(): ctx = rest.Context(method=None, url=None, files={'key': b'content'}) assert ctx.get_file('key') == b'content' - assert ctx.get_file('key2') is None + with pytest.raises(errors.ValidationError): + ctx.get_file('non-existing') def test_get_file_from_url(): @@ -23,30 +24,33 @@ def test_get_file_from_url(): ctx = rest.Context( method=None, url=None, params={'keyUrl': 'example.com'}) assert ctx.get_file('key') == b'content' - assert ctx.get_file('key2') is None net.download.assert_called_once_with('example.com') + with pytest.raises(errors.ValidationError): + assert ctx.get_file('non-existing') def test_getting_list_parameter(): ctx = rest.Context( - method=None, url=None, params={'key': 'value', 'list': list('123')}) + method=None, + url=None, + params={'key': 'value', 'list': ['1', '2', '3']}) assert ctx.get_param_as_list('key') == ['value'] - assert ctx.get_param_as_list('key2') is None - assert ctx.get_param_as_list('key2', default=['def']) == ['def'] assert ctx.get_param_as_list('list') == ['1', '2', '3'] with pytest.raises(errors.ValidationError): - ctx.get_param_as_list('key2', required=True) + ctx.get_param_as_list('non-existing') + assert ctx.get_param_as_list('non-existing', default=['def']) == ['def'] def test_getting_string_parameter(): ctx = rest.Context( - method=None, url=None, params={'key': 'value', 'list': list('123')}) + method=None, + url=None, + params={'key': 'value', 'list': ['1', '2', '3']}) assert ctx.get_param_as_string('key') == 'value' - assert ctx.get_param_as_string('key2') is None - assert ctx.get_param_as_string('key2', default='def') == 'def' assert ctx.get_param_as_string('list') == '1,2,3' with pytest.raises(errors.ValidationError): - ctx.get_param_as_string('key2', required=True) + ctx.get_param_as_string('non-existing') + assert ctx.get_param_as_string('non-existing', default='x') == 'x' def test_getting_int_parameter(): @@ -55,12 +59,11 @@ def test_getting_int_parameter(): url=None, params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}) assert ctx.get_param_as_int('key') == 50 - assert ctx.get_param_as_int('key2') is None - assert ctx.get_param_as_int('key2', default=5) == 5 with pytest.raises(errors.ValidationError): ctx.get_param_as_int('list') with pytest.raises(errors.ValidationError): - ctx.get_param_as_int('key2', required=True) + ctx.get_param_as_int('non-existing') + assert ctx.get_param_as_int('non-existing', default=5) == 5 with pytest.raises(errors.ValidationError): ctx.get_param_as_int('err') with pytest.raises(errors.ValidationError): @@ -102,7 +105,6 @@ def test_getting_bool_parameter(): test(['1', '2']) ctx = rest.Context(method=None, url=None) - assert ctx.get_param_as_bool('non-existing') is None - assert ctx.get_param_as_bool('non-existing', default=True) is True with pytest.raises(errors.ValidationError): - assert ctx.get_param_as_bool('non-existing', required=True) is None + ctx.get_param_as_bool('non-existing') + assert ctx.get_param_as_bool('non-existing', default=True) is True diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index d5796779..945a5e4f 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -1,13 +1,13 @@ # pylint: disable=redefined-outer-name from datetime import datetime import pytest -from szurubooru import db, errors, search +from szurubooru import db, model, errors, search @pytest.fixture def fav_factory(user_factory): def factory(post, user=None): - return db.PostFavorite( + return model.PostFavorite( post=post, user=user or user_factory(), time=datetime.utcnow()) @@ -17,7 +17,7 @@ def fav_factory(user_factory): @pytest.fixture def score_factory(user_factory): def factory(post, user=None, score=1): - return db.PostScore( + return model.PostScore( post=post, user=user or user_factory(), time=datetime.utcnow(), @@ -28,7 +28,7 @@ def score_factory(user_factory): @pytest.fixture def note_factory(): def factory(): - return db.PostNote(polygon='...', text='...') + return model.PostNote(polygon='...', text='...') return factory @@ -36,11 +36,11 @@ def note_factory(): def feature_factory(user_factory): def factory(post=None): if post: - return db.PostFeature( + return model.PostFeature( time=datetime.utcnow(), user=user_factory(), post=post) - return db.PostFeature( + return model.PostFeature( time=datetime.utcnow(), user=user_factory()) return factory @@ -123,7 +123,7 @@ def test_filter_by_score( post3 = post_factory(id=3) for post in [post1, post2, post3]: db.session.add( - db.PostScore( + model.PostScore( score=post.post_id, time=datetime.utcnow(), post=post, @@ -332,10 +332,10 @@ def test_filter_by_type( post2 = post_factory(id=2) post3 = post_factory(id=3) post4 = post_factory(id=4) - post1.type = db.Post.TYPE_IMAGE - post2.type = db.Post.TYPE_ANIMATION - post3.type = db.Post.TYPE_VIDEO - post4.type = db.Post.TYPE_FLASH + post1.type = model.Post.TYPE_IMAGE + post2.type = model.Post.TYPE_ANIMATION + post3.type = model.Post.TYPE_VIDEO + post4.type = model.Post.TYPE_FLASH db.session.add_all([post1, post2, post3, post4]) db.session.flush() verify_unpaged(input, expected_post_ids) @@ -352,9 +352,9 @@ def test_filter_by_safety( post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.safety = db.Post.SAFETY_SAFE - post2.safety = db.Post.SAFETY_SKETCHY - post3.safety = db.Post.SAFETY_UNSAFE + post1.safety = model.Post.SAFETY_SAFE + post2.safety = model.Post.SAFETY_SKETCHY + post3.safety = model.Post.SAFETY_UNSAFE db.session.add_all([post1, post2, post3]) db.session.flush() verify_unpaged(input, expected_post_ids) diff --git a/server/test b/server/test index 6d7bb6de..69cfe542 100755 --- a/server/test +++ b/server/test @@ -4,4 +4,5 @@ import sys pytest.main([ '--cov-report=term-missing', '--cov=szurubooru', + '--tb=short', ] + (sys.argv[1:] or ['szurubooru'])) From e49008034787b16b5a5928e5a1948cf792b8f028 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 15:23:32 +0100 Subject: [PATCH 029/159] server/scripts: remove migration script It was unmaintained for months (years?) anyway --- server/migrate-v1 | 387 ---------------------------------------------- 1 file changed, 387 deletions(-) delete mode 100755 server/migrate-v1 diff --git a/server/migrate-v1 b/server/migrate-v1 deleted file mode 100755 index d3ec0dda..00000000 --- a/server/migrate-v1 +++ /dev/null @@ -1,387 +0,0 @@ -#!/usr/bin/env python3 -import os -import sys -import datetime -import argparse -import json -import zlib -import concurrent.futures -import logging -import coloredlogs -import sqlalchemy as sa -from szurubooru import config, db -from szurubooru.func import files, images, posts, comments - -coloredlogs.install(fmt='[%(asctime)-15s] %(message)s') -logger = logging.getLogger(__name__) - -def read_file(path): - with open(path, 'rb') as handle: - return handle.read() - -def translate_note_polygon(row): - x, y = row['x'], row['y'] - w, h = row.get('width', row.get('w')), row.get('height', row.get('h')) - x /= 100.0 - y /= 100.0 - w /= 100.0 - h /= 100.0 - return [ - (x, y ), - (x + w, y ), - (x + w, y + h), - (x, y + h), - ] - -def get_v1_session(args): - dsn = '{schema}://{user}:{password}@{host}:{port}/{name}?charset=utf8'.format( - schema='mysql+pymysql', - user=args.user, - password=args.password, - host=args.host, - port=args.port, - name=args.name) - logger.info('Connecting to %r...', dsn) - engine = sa.create_engine(dsn) - session_maker = sa.orm.sessionmaker(bind=engine) - return session_maker() - -def parse_args(): - parser = argparse.ArgumentParser( - description='Migrate database from szurubooru v1.x to v2.x.\n\n') - - parser.add_argument( - '--data-dir', - metavar='PATH', required=True, - help='root directory of v1.x deploy') - - parser.add_argument( - '--host', - metavar='HOST', required=True, - help='name of v1.x database host') - - parser.add_argument( - '--port', - metavar='NUM', type=int, default=3306, - help='port to v1.x database host') - - parser.add_argument( - '--name', - metavar='HOST', required=True, - help='v1.x database name') - - parser.add_argument( - '--user', - metavar='NAME', required=True, - help='v1.x database user') - - parser.add_argument( - '--password', - metavar='PASSWORD', required=True, - help='v1.x database password') - - parser.add_argument( - '--no-data', - action='store_true', - help='don\'t migrate post data') - - return parser.parse_args() - -def exec_query(session, query): - for row in list(session.execute(query)): - row = dict(zip(row.keys(), row)) - yield row - -def exec_scalar(session, query): - rows = list(exec_query(session, query)) - first_row = rows[0] - return list(first_row.values())[0] - -def import_users(v1_data_dir, v1_session, v2_session, no_data=False): - for row in exec_query(v1_session, 'SELECT * FROM users'): - logger.info('Importing user %s...', row['name']) - user = db.User() - user.user_id = row['id'] - user.name = row['name'] - user.password_salt = row['passwordSalt'] - user.password_hash = row['passwordHash'] - user.email = row['email'] - user.rank = { - 6: db.User.RANK_ADMINISTRATOR, - 5: db.User.RANK_MODERATOR, - 4: db.User.RANK_POWER, - 3: db.User.RANK_REGULAR, - 2: db.User.RANK_RESTRICTED, - 1: db.User.RANK_ANONYMOUS, - }[row['accessRank']] - user.creation_time = row['creationTime'] - user.last_login_time = row['lastLoginTime'] - user.avatar_style = { - 2: db.User.AVATAR_MANUAL, - 1: db.User.AVATAR_GRAVATAR, - 0: db.User.AVATAR_GRAVATAR, - }[row['avatarStyle']] - v2_session.add(user) - - if user.avatar_style == db.User.AVATAR_MANUAL: - source_avatar_path = os.path.join( - v1_data_dir, 'public_html', 'data', 'avatars', str(user.user_id)) - target_avatar_path = 'avatars/' + user.name.lower() + '.png' - if not no_data and not files.has(target_avatar_path): - avatar_content = read_file(source_avatar_path) - image = images.Image(avatar_content) - image.resize_fill( - int(config.config['thumbnails']['avatar_width']), - int(config.config['thumbnails']['avatar_height'])) - files.save(target_avatar_path, image.to_png()) - counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM users') + 1 - v2_session.execute('ALTER SEQUENCE user_id_seq RESTART WITH %d' % counter) - v2_session.commit() - -def import_tag_categories(v1_session, v2_session): - category_to_id_map = {} - for row in exec_query(v1_session, 'SELECT DISTINCT category FROM tags'): - logger.info('Importing tag category %s...', row['category']) - category = db.TagCategory() - category.tag_category_id = len(category_to_id_map) + 1 - category.name = row['category'] - category.color = 'default' - v2_session.add(category) - category_to_id_map[category.name] = category.tag_category_id - v2_session.execute( - 'ALTER SEQUENCE tag_category_id_seq RESTART WITH %d' % ( - len(category_to_id_map) + 1,)) - return category_to_id_map - -def import_tags(category_to_id_map, v1_session, v2_session): - unused_tag_ids = [] - for row in exec_query(v1_session, 'SELECT * FROM tags'): - logger.info('Importing tag %s...', row['name']) - if row['banned']: - logger.info('Ignored banned tag %s', row['name']) - unused_tag_ids.append(row['id']) - continue - tag = db.Tag() - tag.tag_id = row['id'] - tag.names = [db.TagName(name=row['name'])] - tag.category_id = category_to_id_map[row['category']] - tag.creation_time = row['creationTime'] - tag.last_edit_time = row['lastEditTime'] - v2_session.add(tag) - counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM tags') + 1 - v2_session.execute('ALTER SEQUENCE tag_id_seq RESTART WITH %d' % counter) - v2_session.commit() - return unused_tag_ids - -def import_tag_relations(unused_tag_ids, v1_session, v2_session): - logger.info('Importing tag relations...') - for row in exec_query(v1_session, 'SELECT * FROM tagRelations'): - if row['tag1id'] in unused_tag_ids or row['tag2id'] in unused_tag_ids: - continue - if row['type'] == 1: - v2_session.add( - db.TagImplication( - parent_id=row['tag1id'], child_id=row['tag2id'])) - else: - v2_session.add( - db.TagSuggestion( - parent_id=row['tag1id'], child_id=row['tag2id'])) - v2_session.commit() - -def import_posts(v1_session, v2_session): - unused_post_ids = [] - for row in exec_query(v1_session, 'SELECT * FROM posts'): - logger.info('Importing post %d...', row['id']) - if row['contentType'] == 4: - logger.warn('Ignoring youtube post %d', row['id']) - unused_post_ids.append(row['id']) - continue - post = db.Post() - post.post_id = row['id'] - post.user_id = row['userId'] - post.type = { - 1: db.Post.TYPE_IMAGE, - 2: db.Post.TYPE_FLASH, - 3: db.Post.TYPE_VIDEO, - 5: db.Post.TYPE_ANIMATION, - }[row['contentType']] - post.source = row['source'] - post.canvas_width = row['imageWidth'] - post.canvas_height = row['imageHeight'] - post.file_size = row['originalFileSize'] - post.creation_time = row['creationTime'] - post.last_edit_time = row['lastEditTime'] - post.checksum = row['contentChecksum'] - post.mime_type = row['contentMimeType'] - post.safety = { - 1: db.Post.SAFETY_SAFE, - 2: db.Post.SAFETY_SKETCHY, - 3: db.Post.SAFETY_UNSAFE, - }[row['safety']] - if row['flags'] & 1: - post.flags = [db.Post.FLAG_LOOP] - v2_session.add(post) - counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM posts') + 1 - v2_session.execute('ALTER SEQUENCE post_id_seq RESTART WITH %d' % counter) - v2_session.commit() - return unused_post_ids - -def _import_post_content_for_post( - unused_post_ids, v1_data_dir, v1_session, v2_session, row, post): - logger.info('Importing post %d content...', row['id']) - if row['id'] in unused_post_ids: - logger.warn('Ignoring unimported post %d', row['id']) - return - assert post - source_content_path = os.path.join( - v1_data_dir, - 'public_html', - 'data', - 'posts', - row['name']) - source_thumb_path = os.path.join( - v1_data_dir, - 'public_html', - 'data', - 'posts', - row['name'] + '-custom-thumb') - target_content_path = posts.get_post_content_path(post) - target_thumb_path = posts.get_post_thumbnail_backup_path(post) - if not files.has(target_content_path): - post_content = read_file(source_content_path) - files.save(target_content_path, post_content) - if os.path.exists(source_thumb_path) and not files.has(target_thumb_path): - thumb_content = read_file(source_thumb_path) - files.save(target_thumb_path, thumb_content) - if not files.has(posts.get_post_thumbnail_path(post)): - posts.generate_post_thumbnail(post) - -def import_post_content(unused_post_ids, v1_data_dir, v1_session, v2_session): - rows = list(exec_query(v1_session, 'SELECT * FROM posts')) - posts = {post.post_id: post for post in v2_session.query(db.Post).all()} - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - for row in rows: - post = posts.get(row['id']) - executor.submit( - _import_post_content_for_post, - unused_post_ids, v1_data_dir, v1_session, v2_session, row, post) - -def import_post_tags(unused_post_ids, v1_session, v2_session): - logger.info('Importing post tags...') - for row in exec_query(v1_session, 'SELECT * FROM postTags'): - if row['postId'] in unused_post_ids: - continue - v2_session.add(db.PostTag(post_id=row['postId'], tag_id=row['tagId'])) - v2_session.commit() - -def import_post_notes(unused_post_ids, v1_session, v2_session): - logger.info('Importing post notes...') - for row in exec_query(v1_session, 'SELECT * FROM postNotes'): - if row['postId'] in unused_post_ids: - continue - post_note = db.PostNote() - post_note.post_id = row['postId'] - post_note.text = row['text'] - post_note.polygon = translate_note_polygon(row) - v2_session.add(post_note) - v2_session.commit() - -def import_post_relations(unused_post_ids, v1_session, v2_session): - logger.info('Importing post relations...') - for row in exec_query(v1_session, 'SELECT * FROM postRelations'): - if row['post1id'] in unused_post_ids or row['post2id'] in unused_post_ids: - continue - v2_session.add( - db.PostRelation( - parent_id=row['post1id'], child_id=row['post2id'])) - v2_session.add( - db.PostRelation( - parent_id=row['post2id'], child_id=row['post1id'])) - v2_session.commit() - -def import_post_favorites(unused_post_ids, v1_session, v2_session): - logger.info('Importing post favorites...') - for row in exec_query(v1_session, 'SELECT * FROM favorites'): - if row['postId'] in unused_post_ids: - continue - v2_session.add( - db.PostFavorite( - post_id=row['postId'], - user_id=row['userId'], - time=row['time'] or datetime.datetime.min)) - v2_session.commit() - -def import_comments(unused_post_ids, v1_session, v2_session): - for row in exec_query(v1_session, 'SELECT * FROM comments'): - logger.info('Importing comment %d...', row['id']) - if row['postId'] in unused_post_ids: - logger.warn('Ignoring comment for unimported post %d', row['postId']) - continue - if not posts.try_get_post_by_id(row['postId']): - logger.warn('Ignoring comment for non existing post %d', row['postId']) - continue - comment = db.Comment() - comment.comment_id = row['id'] - comment.user_id = row['userId'] - comment.post_id = row['postId'] - comment.creation_time = row['creationTime'] - comment.last_edit_time = row['lastEditTime'] - comment.text = row['text'] - v2_session.add(comment) - counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM comments') + 1 - v2_session.execute('ALTER SEQUENCE comment_id_seq RESTART WITH %d' % counter) - v2_session.commit() - -def import_scores(v1_session, v2_session): - logger.info('Importing scores...') - for row in exec_query(v1_session, 'SELECT * FROM scores'): - if row['postId']: - post = posts.try_get_post_by_id(row['postId']) - if not post: - logger.warn('Ignoring score for unimported post %d', row['postId']) - continue - score = db.PostScore() - score.post = post - elif row['commentId']: - comment = comments.try_get_comment_by_id(row['commentId']) - if not comment: - logger.warn('Ignoring score for unimported comment %d', row['commentId']) - continue - score = db.CommentScore() - score.comment = comment - score.score = row['score'] - score.time = row['time'] or datetime.datetime.min - score.user_id = row['userId'] - v2_session.add(score) - v2_session.commit() - -def main(): - args = parse_args() - - v1_data_dir = args.data_dir - v1_session = get_v1_session(args) - v2_session = db.session - - if v2_session.query(db.Tag).count() \ - or v2_session.query(db.Post).count() \ - or v2_session.query(db.Comment).count() \ - or v2_session.query(db.User).count(): - logger.error('v2.x database is dirty! Aborting.') - sys.exit(1) - - import_users(v1_data_dir, v1_session, v2_session, args.no_data) - category_to_id_map = import_tag_categories(v1_session, v2_session) - unused_tag_ids = import_tags(category_to_id_map, v1_session, v2_session) - import_tag_relations(unused_tag_ids, v1_session, v2_session) - unused_post_ids = import_posts(v1_session, v2_session) - if not args.no_data: - import_post_content(unused_post_ids, v1_data_dir, v1_session, v2_session) - import_post_tags(unused_post_ids, v1_session, v2_session) - import_post_notes(unused_post_ids, v1_session, v2_session) - import_post_relations(unused_post_ids, v1_session, v2_session) - import_post_favorites(unused_post_ids, v1_session, v2_session) - import_comments(unused_post_ids, v1_session, v2_session) - import_scores(v1_session, v2_session) - -if __name__ == '__main__': - main() From 350e9dd3310e95e19a2bd42f55d3d3aea25a0380 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 15:34:18 +0100 Subject: [PATCH 030/159] server/scripts: replace ./test with setup.cfg --- server/setup.cfg | 3 +++ server/test | 8 -------- 2 files changed, 3 insertions(+), 8 deletions(-) create mode 100644 server/setup.cfg delete mode 100755 server/test diff --git a/server/setup.cfg b/server/setup.cfg new file mode 100644 index 00000000..7e835b45 --- /dev/null +++ b/server/setup.cfg @@ -0,0 +1,3 @@ +[tool:pytest] +testpaths=szurubooru +addopts=--cov-report=term-missing --cov=szurubooru --tb=short diff --git a/server/test b/server/test deleted file mode 100755 index 69cfe542..00000000 --- a/server/test +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python3 -import pytest -import sys -pytest.main([ - '--cov-report=term-missing', - '--cov=szurubooru', - '--tb=short', - ] + (sys.argv[1:] or ['szurubooru'])) From 705967d0fbe38fc3004652792b7b53c4365e2925 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 15:56:27 +0100 Subject: [PATCH 031/159] server/scripts: remove lint Any configuration for pycodestyle should go to the new setup.cfg file. --- server/lint | 3 --- 1 file changed, 3 deletions(-) delete mode 100755 server/lint diff --git a/server/lint b/server/lint deleted file mode 100755 index 218d3bb4..00000000 --- a/server/lint +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh -pylint szurubooru -pycodestyle szurubooru From e725f4f99c3d1100ec583706b9c9482271c28532 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 16:08:46 +0100 Subject: [PATCH 032/159] server/api: extra validation of list fields --- server/szurubooru/api/post_api.py | 14 ++++++++------ server/szurubooru/api/tag_api.py | 12 ++++++------ server/szurubooru/rest/context.py | 22 ++++++++++++++++++++++ 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 6c76688f..4b364ec4 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -47,14 +47,14 @@ def create_post( else: auth.verify_privilege(ctx.user, 'posts:create:identified') content = ctx.get_file('content') - tag_names = ctx.get_param_as_list('tags', default=[]) + tag_names = ctx.get_param_as_string_list('tags', default=[]) safety = ctx.get_param_as_string('safety') source = ctx.get_param_as_string('source', default='') if ctx.has_param('contentUrl') and not source: source = ctx.get_param_as_string('contentUrl', default='') - relations = ctx.get_param_as_list('relations', default=[]) + relations = ctx.get_param_as_int_list('relations', default=[]) notes = ctx.get_param_as_list('notes', default=[]) - flags = ctx.get_param_as_list('flags', default=[]) + flags = ctx.get_param_as_string_list('flags', default=[]) post, new_tags = posts.create_post( content, tag_names, None if anonymous else ctx.user) @@ -94,7 +94,8 @@ def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: posts.update_post_content(post, ctx.get_file('content')) if ctx.has_param('tags'): auth.verify_privilege(ctx.user, 'posts:edit:tags') - new_tags = posts.update_post_tags(post, ctx.get_param_as_list('tags')) + new_tags = posts.update_post_tags( + post, ctx.get_param_as_string_list('tags')) if len(new_tags): auth.verify_privilege(ctx.user, 'tags:create') db.session.flush() @@ -110,13 +111,14 @@ def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: posts.update_post_source(post, ctx.get_param_as_string('contentUrl')) if ctx.has_param('relations'): auth.verify_privilege(ctx.user, 'posts:edit:relations') - posts.update_post_relations(post, ctx.get_param_as_list('relations')) + posts.update_post_relations( + post, ctx.get_param_as_int_list('relations')) if ctx.has_param('notes'): auth.verify_privilege(ctx.user, 'posts:edit:notes') posts.update_post_notes(post, ctx.get_param_as_list('notes')) if ctx.has_param('flags'): auth.verify_privilege(ctx.user, 'posts:edit:flags') - posts.update_post_flags(post, ctx.get_param_as_list('flags')) + posts.update_post_flags(post, ctx.get_param_as_string_list('flags')) if ctx.has_file('thumbnail'): auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 7a379b3c..8e1afbdf 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -38,11 +38,11 @@ def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:create') - names = ctx.get_param_as_list('names') + names = ctx.get_param_as_string_list('names') category = ctx.get_param_as_string('category') description = ctx.get_param_as_string('description', default='') - suggestions = ctx.get_param_as_list('suggestions', default=[]) - implications = ctx.get_param_as_list('implications', default=[]) + suggestions = ctx.get_param_as_string_list('suggestions', default=[]) + implications = ctx.get_param_as_string_list('implications', default=[]) _create_if_needed(suggestions, ctx.user) _create_if_needed(implications, ctx.user) @@ -71,7 +71,7 @@ def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: versions.bump_version(tag) if ctx.has_param('names'): auth.verify_privilege(ctx.user, 'tags:edit:names') - tags.update_tag_names(tag, ctx.get_param_as_list('names')) + tags.update_tag_names(tag, ctx.get_param_as_string_list('names')) if ctx.has_param('category'): auth.verify_privilege(ctx.user, 'tags:edit:category') tags.update_tag_category_name( @@ -82,12 +82,12 @@ def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: tag, ctx.get_param_as_string('description')) if ctx.has_param('suggestions'): auth.verify_privilege(ctx.user, 'tags:edit:suggestions') - suggestions = ctx.get_param_as_list('suggestions') + suggestions = ctx.get_param_as_string_list('suggestions') _create_if_needed(suggestions, ctx.user) tags.update_tag_suggestions(tag, suggestions) if ctx.has_param('implications'): auth.verify_privilege(ctx.user, 'tags:edit:implications') - implications = ctx.get_param_as_list('implications') + implications = ctx.get_param_as_string_list('implications') _create_if_needed(implications, ctx.user) tags.update_tag_implications(tag, implications) tag.last_edit_time = datetime.utcnow() diff --git a/server/szurubooru/rest/context.py b/server/szurubooru/rest/context.py index bb33bfab..0f618add 100644 --- a/server/szurubooru/rest/context.py +++ b/server/szurubooru/rest/context.py @@ -86,6 +86,28 @@ class Context: raise errors.InvalidParameterError( 'Parameter %r must be a list.' % name) + def get_param_as_int_list( + self, + name: str, + default: Union[object, List[int]]=MISSING) -> List[int]: + ret = self.get_param_as_list(name, default) + for item in ret: + if type(item) is not int: + raise errors.InvalidParameterError( + 'Parameter %r must be a list of integer values.' % name) + return ret + + def get_param_as_string_list( + self, + name: str, + default: Union[object, List[str]]=MISSING) -> List[str]: + ret = self.get_param_as_list(name, default) + for item in ret: + if type(item) is not str: + raise errors.InvalidParameterError( + 'Parameter %r must be a list of string values.' % name) + return ret + def get_param_as_string( self, name: str, From 6cc18be68d979f6d1811fb95700a10c22da1a518 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 16:54:01 +0100 Subject: [PATCH 033/159] client/posts: fix editing post relations Regression since e725f4f9 --- client/js/controls/post_edit_sidebar_control.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/js/controls/post_edit_sidebar_control.js b/client/js/controls/post_edit_sidebar_control.js index 7ba37d3d..70570d75 100644 --- a/client/js/controls/post_edit_sidebar_control.js +++ b/client/js/controls/post_edit_sidebar_control.js @@ -274,7 +274,8 @@ class PostEditSidebarControl extends events.EventTarget { undefined, relations: this._relationsInputNode ? - misc.splitByWhitespace(this._relationsInputNode.value) : + misc.splitByWhitespace(this._relationsInputNode.value) + .map(x => parseInt(x)) : undefined, content: this._newPostContent ? From 1f14f2fc165f489dcd2b9b964d0659f35ba41d5b Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 21:46:24 +0100 Subject: [PATCH 034/159] docs/api: add info about wildcards --- API.md | 54 +++++++++++++++---------------- client/html/help_search_posts.tpl | 8 ++--- client/html/help_search_tags.tpl | 2 +- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/API.md b/API.md index b060e92c..ec0bb392 100644 --- a/API.md +++ b/API.md @@ -431,21 +431,21 @@ data. **Named tokens** - | `` | Description | - | ------------------- | ------------------------------------- | - | `name` | having given name (accepts wildcards) | - | `category` | having given category | - | `creation-date` | created at given date | - | `creation-time` | alias of `creation-date` | - | `last-edit-date` | edited at given date | - | `last-edit-time` | alias of `last-edit-date` | - | `edit-date` | alias of `last-edit-date` | - | `edit-time` | alias of `last-edit-date` | - | `usages` | used in given number of posts | - | `usage-count` | alias of `usages` | - | `post-count` | alias of `usages` | - | `suggestion-count` | with given number of suggestions | - | `implication-count` | with given number of implications | + | `` | Description | + | ------------------- | ----------------------------------------- | + | `name` | having given name (accepts wildcards) | + | `category` | having given category (accepts wildcards) | + | `creation-date` | created at given date | + | `creation-time` | alias of `creation-date` | + | `last-edit-date` | edited at given date | + | `last-edit-time` | alias of `last-edit-date` | + | `edit-date` | alias of `last-edit-date` | + | `edit-time` | alias of `last-edit-date` | + | `usages` | used in given number of posts | + | `usage-count` | alias of `usages` | + | `post-count` | alias of `usages` | + | `suggestion-count` | with given number of suggestions | + | `implication-count` | with given number of implications | **Sort style tokens** @@ -699,13 +699,13 @@ data. | `` | Description | | ------------------ | ---------------------------------------------------------- | | `id` | having given post number | - | `tag` | having given tag | + | `tag` | having given tag (accepts wildcards) | | `score` | having given score | - | `uploader` | uploaded by given user | + | `uploader` | uploaded by given user (accepts wildcards) | | `upload` | alias of upload | | `submit` | alias of upload | - | `comment` | commented by given user | - | `fav` | favorited by given user | + | `comment` | commented by given user (accepts wildcards) | + | `fav` | favorited by given user (accepts wildcards) | | `tag-count` | having given number of tags | | `comment-count` | having given number of comments | | `fav-count` | favorited by given number of users | @@ -1555,14 +1555,14 @@ data. **Named tokens** - | `` | Description | - | ----------------- | --------------------------------------------- | - | `type` | involving given resource type | - | `id` | involving given resource id | - | `date` | created at given date | - | `time` | alias of `date` | - | `operation` | `modified`, `created`, `deleted` or `merged` | - | `user` | name of the user that created given snapshot | + | `` | Description | + | ----------------- | ---------------------------------------------------------------- | + | `type` | involving given resource type | + | `id` | involving given resource id | + | `date` | created at given date | + | `time` | alias of `date` | + | `operation` | `modified`, `created`, `deleted` or `merged` | + | `user` | name of the user that created given snapshot (accepts wildcards) | **Sort style tokens** diff --git a/client/html/help_search_posts.tpl b/client/html/help_search_posts.tpl index 074819f9..f1f062c7 100644 --- a/client/html/help_search_posts.tpl +++ b/client/html/help_search_posts.tpl @@ -12,7 +12,7 @@ tag - having given tag + having given tag (accepts wildcards) score @@ -20,7 +20,7 @@ uploader - uploaded by given user + uploaded by given use (accepts wildcards)r upload @@ -32,11 +32,11 @@ comment - commented by given user + commented by given user (accepts wildcards) fav - favorited by given user + favorited by given user (accepts wildcards) tag-count diff --git a/client/html/help_search_tags.tpl b/client/html/help_search_tags.tpl index 38697341..b6fbeccd 100644 --- a/client/html/help_search_tags.tpl +++ b/client/html/help_search_tags.tpl @@ -12,7 +12,7 @@ category - having given category + having given category (accepts wildcards) creation-date From 0b21d98c9bb26db6b5e9aa40e8877af33ec7a390 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 21:50:58 +0100 Subject: [PATCH 035/159] server/posts: support note-text search query --- API.md | 1 + client/html/help_search_posts.tpl | 4 ++ .../search/configs/post_search_config.py | 55 ++++++++++++------- .../search/configs/test_post_search_config.py | 23 +++++++- 4 files changed, 61 insertions(+), 22 deletions(-) diff --git a/API.md b/API.md index ec0bb392..08e2c5c4 100644 --- a/API.md +++ b/API.md @@ -710,6 +710,7 @@ data. | `comment-count` | having given number of comments | | `fav-count` | favorited by given number of users | | `note-count` | having given number of annotations | + | `note-text` | having given note text (accepts wildcards) | | `relation-count` | having given number of relations | | `feature-count` | having been featured given number of times | | `type` | given type of posts. `` can be either `image`, `animation` (or `animated` or `anim`), `flash` (or `swf`) or `video` (or `webm`). | diff --git a/client/html/help_search_posts.tpl b/client/html/help_search_posts.tpl index f1f062c7..1c4ea86e 100644 --- a/client/html/help_search_posts.tpl +++ b/client/html/help_search_posts.tpl @@ -54,6 +54,10 @@ note-count having given number of annotations + + note-text + having given note text (accepts wildcards) + relation-count having given number of relations diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index cda1b1ac..003b4c25 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -69,25 +69,35 @@ def _create_score_filter(score: int) -> Filter: return wrapper -def _create_user_filter() -> Filter: - def wrapper( - query: SaQuery, - criterion: Optional[criteria.BaseCriterion], - negated: bool) -> SaQuery: - assert criterion - if isinstance(criterion, criteria.PlainCriterion) \ - and not criterion.value: - # pylint: disable=singleton-comparison - expr = model.Post.user_id == None - if negated: - expr = ~expr - return query.filter(expr) - return search_util.create_subquery_filter( - model.Post.user_id, - model.User.user_id, - model.User.name, - search_util.create_str_filter)(query, criterion, negated) - return wrapper +def _user_filter( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion + if isinstance(criterion, criteria.PlainCriterion) \ + and not criterion.value: + # pylint: disable=singleton-comparison + expr = model.Post.user_id == None + if negated: + expr = ~expr + return query.filter(expr) + return search_util.create_subquery_filter( + model.Post.user_id, + model.User.user_id, + model.User.name, + search_util.create_str_filter)(query, criterion, negated) + + +def _note_filter( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion + return search_util.create_subquery_filter( + model.Post.post_id, + model.PostNote.post_id, + model.PostNote.text, + search_util.create_str_filter)(query, criterion, negated) class PostSearchConfig(BaseSearchConfig): @@ -187,7 +197,7 @@ class PostSearchConfig(BaseSearchConfig): ( ['uploader', 'upload', 'submit'], - _create_user_filter() + _user_filter ), ( @@ -311,6 +321,11 @@ class PostSearchConfig(BaseSearchConfig): search_util.create_str_filter( model.Post.safety, _safety_transformer) ), + + ( + ['note-text'], + _note_filter + ), ]) @property diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index 945a5e4f..738ee410 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -27,8 +27,8 @@ def score_factory(user_factory): @pytest.fixture def note_factory(): - def factory(): - return model.PostNote(polygon='...', text='...') + def factory(text='...'): + return model.PostNote(polygon='...', text=text) return factory @@ -294,6 +294,25 @@ def test_filter_by_note_count( verify_unpaged(input, expected_post_ids) +@pytest.mark.parametrize('input,expected_post_ids', [ + ('note-text:*', [1, 2, 3]), + ('note-text:text2', [2]), + ('note-text:text3*', [3]), + ('note-text:text3a,text2', [2, 3]), +]) +def test_filter_by_note_count( + verify_unpaged, post_factory, note_factory, input, expected_post_ids): + post1 = post_factory(id=1) + post2 = post_factory(id=2) + post3 = post_factory(id=3) + post1.notes = [note_factory(text='text1')] + post2.notes = [note_factory(text='text2'), note_factory(text='text2')] + post3.notes = [note_factory(text='text3a'), note_factory(text='text3b')] + db.session.add_all([post1, post2, post3]) + db.session.flush() + verify_unpaged(input, expected_post_ids) + + @pytest.mark.parametrize('input,expected_post_ids', [ ('feature-count:1', [1]), ('feature-count:3', [3]), From 00c3a4320bfbb4ee80ba132232a279a8982570df Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 22:09:33 +0100 Subject: [PATCH 036/159] server/posts: support aspect-ratio search query --- API.md | 88 ++++++++++--------- client/html/help_search_posts.tpl | 16 ++++ server/szurubooru/model/post.py | 1 + .../search/configs/post_search_config.py | 7 ++ server/szurubooru/search/configs/util.py | 38 ++++++-- .../search/configs/test_post_search_config.py | 32 ++++--- 6 files changed, 120 insertions(+), 62 deletions(-) diff --git a/API.md b/API.md index 08e2c5c4..8b25a516 100644 --- a/API.md +++ b/API.md @@ -696,48 +696,52 @@ data. **Named tokens** - | `` | Description | - | ------------------ | ---------------------------------------------------------- | - | `id` | having given post number | - | `tag` | having given tag (accepts wildcards) | - | `score` | having given score | - | `uploader` | uploaded by given user (accepts wildcards) | - | `upload` | alias of upload | - | `submit` | alias of upload | - | `comment` | commented by given user (accepts wildcards) | - | `fav` | favorited by given user (accepts wildcards) | - | `tag-count` | having given number of tags | - | `comment-count` | having given number of comments | - | `fav-count` | favorited by given number of users | - | `note-count` | having given number of annotations | - | `note-text` | having given note text (accepts wildcards) | - | `relation-count` | having given number of relations | - | `feature-count` | having been featured given number of times | - | `type` | given type of posts. `` can be either `image`, `animation` (or `animated` or `anim`), `flash` (or `swf`) or `video` (or `webm`). | - | `content-checksum` | having given SHA1 checksum | - | `file-size` | having given file size (in bytes) | - | `image-width` | having given image width (where applicable) | - | `image-height` | having given image height (where applicable) | - | `image-area` | having given number of pixels (image width * image height) | - | `width` | alias of `image-width` | - | `height` | alias of `image-height` | - | `area` | alias of `image-area` | - | `creation-date` | posted at given date | - | `creation-time` | alias of `creation-date` | - | `date` | alias of `creation-date` | - | `time` | alias of `creation-date` | - | `last-edit-date` | edited at given date | - | `last-edit-time` | alias of `last-edit-date` | - | `edit-date` | alias of `last-edit-date` | - | `edit-time` | alias of `last-edit-date` | - | `comment-date` | commented at given date | - | `comment-time` | alias of `comment-date` | - | `fav-date` | last favorited at given date | - | `fav-time` | alias of `fav-date` | - | `feature-date` | featured at given date | - | `feature-time` | alias of `feature-time` | - | `safety` | having given safety. `` can be either `safe`, `sketchy` (or `questionable`) or `unsafe`. | - | `rating` | alias of `safety` | + | `` | Description | + | -------------------- | ---------------------------------------------------------- | + | `id` | having given post number | + | `tag` | having given tag (accepts wildcards) | + | `score` | having given score | + | `uploader` | uploaded by given user (accepts wildcards) | + | `upload` | alias of upload | + | `submit` | alias of upload | + | `comment` | commented by given user (accepts wildcards) | + | `fav` | favorited by given user (accepts wildcards) | + | `tag-count` | having given number of tags | + | `comment-count` | having given number of comments | + | `fav-count` | favorited by given number of users | + | `note-count` | having given number of annotations | + | `note-text` | having given note text (accepts wildcards) | + | `relation-count` | having given number of relations | + | `feature-count` | having been featured given number of times | + | `type` | given type of posts. `` can be either `image`, `animation` (or `animated` or `anim`), `flash` (or `swf`) or `video` (or `webm`). | + | `content-checksum` | having given SHA1 checksum | + | `file-size` | having given file size (in bytes) | + | `image-width` | having given image width (where applicable) | + | `image-height` | having given image height (where applicable) | + | `image-area` | having given number of pixels (image width * image height) | + | `image-aspect-ratio` | having given aspect ratio (image width / image height) | + | `image-ar` | alias of `image-aspect-ratio` | + | `width` | alias of `image-width` | + | `height` | alias of `image-height` | + | `area` | alias of `image-area` | + | `ar` | alias of `image-aspect-ratio` | + | `aspect-ratio` | alias of `image-aspect-ratio` | + | `creation-date` | posted at given date | + | `creation-time` | alias of `creation-date` | + | `date` | alias of `creation-date` | + | `time` | alias of `creation-date` | + | `last-edit-date` | edited at given date | + | `last-edit-time` | alias of `last-edit-date` | + | `edit-date` | alias of `last-edit-date` | + | `edit-time` | alias of `last-edit-date` | + | `comment-date` | commented at given date | + | `comment-time` | alias of `comment-date` | + | `fav-date` | last favorited at given date | + | `fav-time` | alias of `fav-date` | + | `feature-date` | featured at given date | + | `feature-time` | alias of `feature-time` | + | `safety` | having given safety. `` can be either `safe`, `sketchy` (or `questionable`) or `unsafe`. | + | `rating` | alias of `safety` | **Sort style tokens** diff --git a/client/html/help_search_posts.tpl b/client/html/help_search_posts.tpl index 1c4ea86e..a9e1a8e0 100644 --- a/client/html/help_search_posts.tpl +++ b/client/html/help_search_posts.tpl @@ -90,6 +90,14 @@ image-area having given number of pixels (image width * image height) + + image-aspect-ratio + having given aspect ratio (image width / image height) + + + image-ar + alias of image-aspect-ratio + width alias of image-width @@ -102,6 +110,14 @@ area alias of image-area + + aspect-ratio + alias of image-aspect-ratio + + + ar + alias of image-aspect-ratio + creation-date posted at given date diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index 23f52b57..0aa04e58 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -194,6 +194,7 @@ class Post(Base): .correlate_except(PostTag)) canvas_area = column_property(canvas_width * canvas_height) + canvas_aspect_ratio = column_property(canvas_width / canvas_height) @property def is_featured(self) -> bool: diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 003b4c25..34e86537 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -290,6 +290,13 @@ class PostSearchConfig(BaseSearchConfig): search_util.create_num_filter(model.Post.canvas_area) ), + ( + ['image-aspect-ratio', 'image-ar', 'aspect-ratio', 'ar'], + search_util.create_num_filter( + model.Post.canvas_aspect_ratio, + transformer=search_util.float_transformer) + ), + ( ['creation-date', 'creation-time', 'date', 'time'], search_util.create_date_filter(model.Post.creation_time) diff --git a/server/szurubooru/search/configs/util.py b/server/szurubooru/search/configs/util.py index 086f3921..c6b9b783 100644 --- a/server/szurubooru/search/configs/util.py +++ b/server/szurubooru/search/configs/util.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Callable +from typing import Any, Optional, Union, Callable import sqlalchemy as sa from szurubooru import db, errors from szurubooru.func import util @@ -7,6 +7,9 @@ from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs.base_search_config import Filter +Number = Union[int, float] + + def wildcard_transformer(value: str) -> str: return ( value @@ -16,22 +19,37 @@ def wildcard_transformer(value: str) -> str: .replace('*', '%')) +def integer_transformer(value: str) -> int: + return int(value) + + +def float_transformer(value: str) -> float: + for sep in list('/:'): + if sep in value: + a, b = value.split(sep, 1) + return float(a) / float(b) + return float(value) + + def apply_num_criterion_to_column( - column: Any, criterion: criteria.BaseCriterion) -> Any: + column: Any, + criterion: criteria.BaseCriterion, + transformer: Callable[[str], Number]=integer_transformer) -> SaQuery: try: if isinstance(criterion, criteria.PlainCriterion): - expr = column == int(criterion.value) + expr = column == transformer(criterion.value) elif isinstance(criterion, criteria.ArrayCriterion): - expr = column.in_(int(value) for value in criterion.values) + expr = column.in_(transformer(value) for value in criterion.values) elif isinstance(criterion, criteria.RangedCriterion): assert criterion.min_value or criterion.max_value if criterion.min_value and criterion.max_value: expr = column.between( - int(criterion.min_value), int(criterion.max_value)) + transformer(criterion.min_value), + transformer(criterion.max_value)) elif criterion.min_value: - expr = column >= int(criterion.min_value) + expr = column >= transformer(criterion.min_value) elif criterion.max_value: - expr = column <= int(criterion.max_value) + expr = column <= transformer(criterion.max_value) else: assert False except ValueError: @@ -40,13 +58,15 @@ def apply_num_criterion_to_column( return expr -def create_num_filter(column: Any) -> Filter: +def create_num_filter( + column: Any, + transformer: Callable[[str], Number]=integer_transformer) -> SaQuery: def wrapper( query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool) -> SaQuery: assert criterion - expr = apply_num_criterion_to_column(column, criterion) + expr = apply_num_criterion_to_column(column, criterion, transformer) if negated: expr = ~expr return query.filter(expr) diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index 738ee410..551e09c2 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -422,14 +422,19 @@ def test_filter_by_file_size( @pytest.mark.parametrize('input,expected_post_ids', [ ('image-width:100', [1]), - ('image-width:102', [3]), - ('image-width:100,102', [1, 3]), + ('image-width:200', [2]), + ('image-width:100,300', [1, 3]), ('image-height:200', [1]), - ('image-height:202', [3]), - ('image-height:200,202', [1, 3]), - ('image-area:20000', [1]), - ('image-area:20604', [3]), - ('image-area:20000,20604', [1, 3]), + ('image-height:100', [2]), + ('image-height:200,300', [1, 3]), + ('image-area:20000', [1, 2]), + ('image-area:90000', [3]), + ('image-area:20000,90000', [1, 2, 3]), + ('image-ar:1', [3]), + ('image-ar:..0.9', [1]), + ('image-ar:1.1..', [2]), + ('image-ar:1/1..1/1', [3]), + ('image-ar:1:1..1:1', [3]), ]) def test_filter_by_image_size( verify_unpaged, post_factory, input, expected_post_ids): @@ -437,16 +442,21 @@ def test_filter_by_image_size( post2 = post_factory(id=2) post3 = post_factory(id=3) post1.canvas_width = 100 - post2.canvas_width = 101 - post3.canvas_width = 102 post1.canvas_height = 200 - post2.canvas_height = 201 - post3.canvas_height = 202 + post2.canvas_width = 200 + post2.canvas_height = 100 + post3.canvas_width = 300 + post3.canvas_height = 300 db.session.add_all([post1, post2, post3]) db.session.flush() verify_unpaged(input, expected_post_ids) +def test_filter_by_invalid_aspect_ratio(executor): + with pytest.raises(errors.SearchError): + executor.execute('image-ar:1:1:1', page=1, page_size=100) + + @pytest.mark.parametrize('input,expected_post_ids', [ ('creation-date:2014', [1]), ('creation-date:2016', [3]), From 4caa980bf8bc99c2a392220f608b199a517cc2d7 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 22:38:05 +0100 Subject: [PATCH 037/159] server/build: add missing dependency Althought szurubooru is now no longer dependent from image-match, the pulled code still needs the skimage library. --- server/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/server/requirements.txt b/server/requirements.txt index bef7c14d..7b6d53cf 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -10,3 +10,4 @@ pycodestyle>=2.0.0 scipy>=0.18.1 elasticsearch>=5.0.0 elasticsearch-dsl>=5.0.0 +skimage>=0.12 From f40a8875c441f34065df1ffd5afacdfc948f50a1 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 22:38:55 +0100 Subject: [PATCH 038/159] server/files: fix import for Py3.5 os.DirEntry is available only from Python3.6+. --- server/szurubooru/func/files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/szurubooru/func/files.py b/server/szurubooru/func/files.py index 0a992ee4..fa9f36fd 100644 --- a/server/szurubooru/func/files.py +++ b/server/szurubooru/func/files.py @@ -17,7 +17,7 @@ def has(path: str) -> bool: return os.path.exists(_get_full_path(path)) -def scan(path: str) -> List[os.DirEntry]: +def scan(path: str) -> List[Any]: if has(path): return list(os.scandir(_get_full_path(path))) return [] From 49e597525403b593b9721c81a9ef13484472f868 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 23:19:05 +0100 Subject: [PATCH 039/159] server/model: use new sqlalchemy import style --- server/szurubooru/model/comment.py | 60 ++++--- server/szurubooru/model/post.py | 226 +++++++++++++----------- server/szurubooru/model/snapshot.py | 32 ++-- server/szurubooru/model/tag.py | 86 ++++----- server/szurubooru/model/tag_category.py | 21 ++- 5 files changed, 227 insertions(+), 198 deletions(-) diff --git a/server/szurubooru/model/comment.py b/server/szurubooru/model/comment.py index 55c1596b..17b76a05 100644 --- a/server/szurubooru/model/comment.py +++ b/server/szurubooru/model/comment.py @@ -1,6 +1,4 @@ -from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey -from sqlalchemy.orm import relationship, backref -from sqlalchemy.sql.expression import func +import sqlalchemy as sa from szurubooru.db import get_session from szurubooru.model.base import Base @@ -8,51 +6,59 @@ from szurubooru.model.base import Base class CommentScore(Base): __tablename__ = 'comment_score' - comment_id = Column( + comment_id = sa.Column( 'comment_id', - Integer, - ForeignKey('comment.id'), + sa.Integer, + sa.ForeignKey('comment.id'), nullable=False, primary_key=True) - user_id = Column( + user_id = sa.Column( 'user_id', - Integer, - ForeignKey('user.id'), + sa.Integer, + sa.ForeignKey('user.id'), nullable=False, primary_key=True, index=True) - time = Column('time', DateTime, nullable=False) - score = Column('score', Integer, nullable=False) + time = sa.Column('time', sa.DateTime, nullable=False) + score = sa.Column('score', sa.Integer, nullable=False) - comment = relationship('Comment') - user = relationship( + comment = sa.orm.relationship('Comment') + user = sa.orm.relationship( 'User', - backref=backref('comment_scores', cascade='all, delete-orphan')) + backref=sa.orm.backref('comment_scores', cascade='all, delete-orphan')) class Comment(Base): __tablename__ = 'comment' - comment_id = Column('id', Integer, primary_key=True) - post_id = Column( - 'post_id', Integer, ForeignKey('post.id'), nullable=False, index=True) - user_id = Column( - 'user_id', Integer, ForeignKey('user.id'), nullable=True, index=True) - version = Column('version', Integer, default=1, nullable=False) - creation_time = Column('creation_time', DateTime, nullable=False) - last_edit_time = Column('last_edit_time', DateTime) - text = Column('text', UnicodeText, default=None) + comment_id = sa.Column('id', sa.Integer, primary_key=True) + post_id = sa.Column( + 'post_id', + sa.Integer, + sa.ForeignKey('post.id'), + nullable=False, + index=True) + user_id = sa.Column( + 'user_id', + sa.Integer, + sa.ForeignKey('user.id'), + nullable=True, + index=True) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_edit_time = sa.Column('last_edit_time', sa.DateTime) + text = sa.Column('text', sa.UnicodeText, default=None) - user = relationship('User') - post = relationship('Post') - scores = relationship( + user = sa.orm.relationship('User') + post = sa.orm.relationship('Post') + scores = sa.orm.relationship( 'CommentScore', cascade='all, delete-orphan', lazy='joined') @property def score(self) -> int: return ( get_session() - .query(func.sum(CommentScore.score)) + .query(sa.sql.expression.func.sum(CommentScore.score)) .filter(CommentScore.comment_id == self.comment_id) .one()[0] or 0) diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index 0aa04e58..d0de4855 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -1,8 +1,4 @@ -from sqlalchemy.sql.expression import func, select -from sqlalchemy import ( - Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) -from sqlalchemy.orm import ( - relationship, column_property, object_session, backref) +import sqlalchemy as sa from szurubooru.model.base import Base from szurubooru.model.comment import Comment @@ -10,95 +6,109 @@ from szurubooru.model.comment import Comment class PostFeature(Base): __tablename__ = 'post_feature' - post_feature_id = Column('id', Integer, primary_key=True) - post_id = Column( - 'post_id', Integer, ForeignKey('post.id'), nullable=False, index=True) - user_id = Column( - 'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True) - time = Column('time', DateTime, nullable=False) + post_feature_id = sa.Column('id', sa.Integer, primary_key=True) + post_id = sa.Column( + 'post_id', + sa.Integer, + sa.ForeignKey('post.id'), + nullable=False, + index=True) + user_id = sa.Column( + 'user_id', + sa.Integer, + sa.ForeignKey('user.id'), + nullable=False, + index=True) + time = sa.Column('time', sa.DateTime, nullable=False) - post = relationship('Post') # type: Post - user = relationship( - 'User', backref=backref('post_features', cascade='all, delete-orphan')) + post = sa.orm.relationship('Post') # type: Post + user = sa.orm.relationship( + 'User', + backref=sa.orm.backref( + 'post_features', cascade='all, delete-orphan')) class PostScore(Base): __tablename__ = 'post_score' - post_id = Column( + post_id = sa.Column( 'post_id', - Integer, - ForeignKey('post.id'), + sa.Integer, + sa.ForeignKey('post.id'), primary_key=True, nullable=False, index=True) - user_id = Column( + user_id = sa.Column( 'user_id', - Integer, - ForeignKey('user.id'), + sa.Integer, + sa.ForeignKey('user.id'), primary_key=True, nullable=False, index=True) - time = Column('time', DateTime, nullable=False) - score = Column('score', Integer, nullable=False) + time = sa.Column('time', sa.DateTime, nullable=False) + score = sa.Column('score', sa.Integer, nullable=False) - post = relationship('Post') - user = relationship( + post = sa.orm.relationship('Post') + user = sa.orm.relationship( 'User', - backref=backref('post_scores', cascade='all, delete-orphan')) + backref=sa.orm.backref('post_scores', cascade='all, delete-orphan')) class PostFavorite(Base): __tablename__ = 'post_favorite' - post_id = Column( + post_id = sa.Column( 'post_id', - Integer, - ForeignKey('post.id'), + sa.Integer, + sa.ForeignKey('post.id'), primary_key=True, nullable=False, index=True) - user_id = Column( + user_id = sa.Column( 'user_id', - Integer, - ForeignKey('user.id'), + sa.Integer, + sa.ForeignKey('user.id'), primary_key=True, nullable=False, index=True) - time = Column('time', DateTime, nullable=False) + time = sa.Column('time', sa.DateTime, nullable=False) - post = relationship('Post') - user = relationship( + post = sa.orm.relationship('Post') + user = sa.orm.relationship( 'User', - backref=backref('post_favorites', cascade='all, delete-orphan')) + backref=sa.orm.backref('post_favorites', cascade='all, delete-orphan')) class PostNote(Base): __tablename__ = 'post_note' - post_note_id = Column('id', Integer, primary_key=True) - post_id = Column( - 'post_id', Integer, ForeignKey('post.id'), nullable=False, index=True) - polygon = Column('polygon', PickleType, nullable=False) - text = Column('text', UnicodeText, nullable=False) + post_note_id = sa.Column('id', sa.Integer, primary_key=True) + post_id = sa.Column( + 'post_id', + sa.Integer, + sa.ForeignKey('post.id'), + nullable=False, + index=True) + polygon = sa.Column('polygon', sa.PickleType, nullable=False) + text = sa.Column('text', sa.UnicodeText, nullable=False) - post = relationship('Post') + post = sa.orm.relationship('Post') class PostRelation(Base): __tablename__ = 'post_relation' - parent_id = Column( + parent_id = sa.Column( 'parent_id', - Integer, - ForeignKey('post.id'), + sa.Integer, + sa.ForeignKey('post.id'), primary_key=True, nullable=False, index=True) - child_id = Column( + child_id = sa.Column( 'child_id', - Integer, - ForeignKey('post.id'), + sa.Integer, + sa.ForeignKey('post.id'), primary_key=True, nullable=False, index=True) @@ -111,17 +121,17 @@ class PostRelation(Base): class PostTag(Base): __tablename__ = 'post_tag' - post_id = Column( + post_id = sa.Column( 'post_id', - Integer, - ForeignKey('post.id'), + sa.Integer, + sa.ForeignKey('post.id'), primary_key=True, nullable=False, index=True) - tag_id = Column( + tag_id = sa.Column( 'tag_id', - Integer, - ForeignKey('tag.id'), + sa.Integer, + sa.ForeignKey('tag.id'), primary_key=True, nullable=False, index=True) @@ -146,111 +156,123 @@ class Post(Base): FLAG_LOOP = 'loop' # basic meta - post_id = Column('id', Integer, primary_key=True) - user_id = Column( + post_id = sa.Column('id', sa.Integer, primary_key=True) + user_id = sa.Column( 'user_id', - Integer, - ForeignKey('user.id', ondelete='SET NULL'), + sa.Integer, + sa.ForeignKey('user.id', ondelete='SET NULL'), nullable=True, index=True) - version = Column('version', Integer, default=1, nullable=False) - creation_time = Column('creation_time', DateTime, nullable=False) - last_edit_time = Column('last_edit_time', DateTime) - safety = Column('safety', Unicode(32), nullable=False) - source = Column('source', Unicode(200)) - flags = Column('flags', PickleType, default=None) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_edit_time = sa.Column('last_edit_time', sa.DateTime) + safety = sa.Column('safety', sa.Unicode(32), nullable=False) + source = sa.Column('source', sa.Unicode(200)) + flags = sa.Column('flags', sa.PickleType, default=None) # content description - type = Column('type', Unicode(32), nullable=False) - checksum = Column('checksum', Unicode(64), nullable=False) - file_size = Column('file_size', Integer) - canvas_width = Column('image_width', Integer) - canvas_height = Column('image_height', Integer) - mime_type = Column('mime-type', Unicode(32), nullable=False) + type = sa.Column('type', sa.Unicode(32), nullable=False) + checksum = sa.Column('checksum', sa.Unicode(64), nullable=False) + file_size = sa.Column('file_size', sa.Integer) + canvas_width = sa.Column('image_width', sa.Integer) + canvas_height = sa.Column('image_height', sa.Integer) + mime_type = sa.Column('mime-type', sa.Unicode(32), nullable=False) # foreign tables - user = relationship('User') - tags = relationship('Tag', backref='posts', secondary='post_tag') - relations = relationship( + user = sa.orm.relationship('User') + tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag') + relations = sa.orm.relationship( 'Post', secondary='post_relation', primaryjoin=post_id == PostRelation.parent_id, secondaryjoin=post_id == PostRelation.child_id, lazy='joined', backref='related_by') - features = relationship( + features = sa.orm.relationship( 'PostFeature', cascade='all, delete-orphan', lazy='joined') - scores = relationship( + scores = sa.orm.relationship( 'PostScore', cascade='all, delete-orphan', lazy='joined') - favorited_by = relationship( + favorited_by = sa.orm.relationship( 'PostFavorite', cascade='all, delete-orphan', lazy='joined') - notes = relationship( + notes = sa.orm.relationship( 'PostNote', cascade='all, delete-orphan', lazy='joined') - comments = relationship('Comment', cascade='all, delete-orphan') + comments = sa.orm.relationship('Comment', cascade='all, delete-orphan') # dynamic columns - tag_count = column_property( - select([func.count(PostTag.tag_id)]) + tag_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(PostTag.tag_id)]) .where(PostTag.post_id == post_id) .correlate_except(PostTag)) - canvas_area = column_property(canvas_width * canvas_height) - canvas_aspect_ratio = column_property(canvas_width / canvas_height) + canvas_area = sa.orm.column_property(canvas_width * canvas_height) + canvas_aspect_ratio = sa.orm.column_property(canvas_width / canvas_height) @property def is_featured(self) -> bool: - featured_post = object_session(self) \ + featured_post = sa.orm.object_session(self) \ .query(PostFeature) \ .order_by(PostFeature.time.desc()) \ .first() return featured_post and featured_post.post_id == self.post_id - score = column_property( - select([func.coalesce(func.sum(PostScore.score), 0)]) + score = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.coalesce( + sa.sql.expression.func.sum(PostScore.score), 0)]) .where(PostScore.post_id == post_id) .correlate_except(PostScore)) - favorite_count = column_property( - select([func.count(PostFavorite.post_id)]) + favorite_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(PostFavorite.post_id)]) .where(PostFavorite.post_id == post_id) .correlate_except(PostFavorite)) - last_favorite_time = column_property( - select([func.max(PostFavorite.time)]) + last_favorite_time = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.max(PostFavorite.time)]) .where(PostFavorite.post_id == post_id) .correlate_except(PostFavorite)) - feature_count = column_property( - select([func.count(PostFeature.post_id)]) + feature_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(PostFeature.post_id)]) .where(PostFeature.post_id == post_id) .correlate_except(PostFeature)) - last_feature_time = column_property( - select([func.max(PostFeature.time)]) + last_feature_time = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.max(PostFeature.time)]) .where(PostFeature.post_id == post_id) .correlate_except(PostFeature)) - comment_count = column_property( - select([func.count(Comment.post_id)]) + comment_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(Comment.post_id)]) .where(Comment.post_id == post_id) .correlate_except(Comment)) - last_comment_creation_time = column_property( - select([func.max(Comment.creation_time)]) + last_comment_creation_time = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.max(Comment.creation_time)]) .where(Comment.post_id == post_id) .correlate_except(Comment)) - last_comment_edit_time = column_property( - select([func.max(Comment.last_edit_time)]) + last_comment_edit_time = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.max(Comment.last_edit_time)]) .where(Comment.post_id == post_id) .correlate_except(Comment)) - note_count = column_property( - select([func.count(PostNote.post_id)]) + note_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(PostNote.post_id)]) .where(PostNote.post_id == post_id) .correlate_except(PostNote)) - relation_count = column_property( - select([func.count(PostRelation.child_id)]) + relation_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(PostRelation.child_id)]) .where( (PostRelation.parent_id == post_id) | (PostRelation.child_id == post_id)) diff --git a/server/szurubooru/model/snapshot.py b/server/szurubooru/model/snapshot.py index beb3bb25..7f8bbdf0 100644 --- a/server/szurubooru/model/snapshot.py +++ b/server/szurubooru/model/snapshot.py @@ -1,6 +1,4 @@ -from sqlalchemy.orm import relationship -from sqlalchemy import ( - Column, Integer, DateTime, Unicode, PickleType, ForeignKey) +import sqlalchemy as sa from szurubooru.model.base import Base @@ -12,20 +10,20 @@ class Snapshot(Base): OPERATION_DELETED = 'deleted' OPERATION_MERGED = 'merged' - snapshot_id = Column('id', Integer, primary_key=True) - creation_time = Column('creation_time', DateTime, nullable=False) - operation = Column('operation', Unicode(16), nullable=False) - resource_type = Column( - 'resource_type', Unicode(32), nullable=False, index=True) - resource_pkey = Column( - 'resource_pkey', Integer, nullable=False, index=True) - resource_name = Column( - 'resource_name', Unicode(64), nullable=False) - user_id = Column( + snapshot_id = sa.Column('id', sa.Integer, primary_key=True) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + operation = sa.Column('operation', sa.Unicode(16), nullable=False) + resource_type = sa.Column( + 'resource_type', sa.Unicode(32), nullable=False, index=True) + resource_pkey = sa.Column( + 'resource_pkey', sa.Integer, nullable=False, index=True) + resource_name = sa.Column( + 'resource_name', sa.Unicode(64), nullable=False) + user_id = sa.Column( 'user_id', - Integer, - ForeignKey('user.id', ondelete='set null'), + sa.Integer, + sa.ForeignKey('user.id', ondelete='set null'), nullable=True) - data = Column('data', PickleType) + data = sa.Column('data', sa.PickleType) - user = relationship('User') + user = sa.orm.relationship('User') diff --git a/server/szurubooru/model/tag.py b/server/szurubooru/model/tag.py index 1bce3ffa..51059007 100644 --- a/server/szurubooru/model/tag.py +++ b/server/szurubooru/model/tag.py @@ -1,7 +1,4 @@ -from sqlalchemy import ( - Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey) -from sqlalchemy.orm import relationship, column_property -from sqlalchemy.sql.expression import func, select +import sqlalchemy as sa from szurubooru.model.base import Base from szurubooru.model.post import PostTag @@ -9,17 +6,17 @@ from szurubooru.model.post import PostTag class TagSuggestion(Base): __tablename__ = 'tag_suggestion' - parent_id = Column( + parent_id = sa.Column( 'parent_id', - Integer, - ForeignKey('tag.id'), + sa.Integer, + sa.ForeignKey('tag.id'), nullable=False, primary_key=True, index=True) - child_id = Column( + child_id = sa.Column( 'child_id', - Integer, - ForeignKey('tag.id'), + sa.Integer, + sa.ForeignKey('tag.id'), nullable=False, primary_key=True, index=True) @@ -32,17 +29,17 @@ class TagSuggestion(Base): class TagImplication(Base): __tablename__ = 'tag_implication' - parent_id = Column( + parent_id = sa.Column( 'parent_id', - Integer, - ForeignKey('tag.id'), + sa.Integer, + sa.ForeignKey('tag.id'), nullable=False, primary_key=True, index=True) - child_id = Column( + child_id = sa.Column( 'child_id', - Integer, - ForeignKey('tag.id'), + sa.Integer, + sa.ForeignKey('tag.id'), nullable=False, primary_key=True, index=True) @@ -55,11 +52,15 @@ class TagImplication(Base): class TagName(Base): __tablename__ = 'tag_name' - tag_name_id = Column('tag_name_id', Integer, primary_key=True) - tag_id = Column( - 'tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True) - name = Column('name', Unicode(64), nullable=False, unique=True) - order = Column('ord', Integer, nullable=False, index=True) + tag_name_id = sa.Column('tag_name_id', sa.Integer, primary_key=True) + tag_id = sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id'), + nullable=False, + index=True) + name = sa.Column('name', sa.Unicode(64), nullable=False, unique=True) + order = sa.Column('ord', sa.Integer, nullable=False, index=True) def __init__(self, name: str, order: int) -> None: self.name = name @@ -69,45 +70,46 @@ class TagName(Base): class Tag(Base): __tablename__ = 'tag' - tag_id = Column('id', Integer, primary_key=True) - category_id = Column( + tag_id = sa.Column('id', sa.Integer, primary_key=True) + category_id = sa.Column( 'category_id', - Integer, - ForeignKey('tag_category.id'), + sa.Integer, + sa.ForeignKey('tag_category.id'), nullable=False, index=True) - version = Column('version', Integer, default=1, nullable=False) - creation_time = Column('creation_time', DateTime, nullable=False) - last_edit_time = Column('last_edit_time', DateTime) - description = Column('description', UnicodeText, default=None) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_edit_time = sa.Column('last_edit_time', sa.DateTime) + description = sa.Column('description', sa.UnicodeText, default=None) - category = relationship('TagCategory', lazy='joined') - names = relationship( + category = sa.orm.relationship('TagCategory', lazy='joined') + names = sa.orm.relationship( 'TagName', cascade='all,delete-orphan', lazy='joined', order_by='TagName.order') - suggestions = relationship( + suggestions = sa.orm.relationship( 'Tag', secondary='tag_suggestion', primaryjoin=tag_id == TagSuggestion.parent_id, secondaryjoin=tag_id == TagSuggestion.child_id, lazy='joined') - implications = relationship( + implications = sa.orm.relationship( 'Tag', secondary='tag_implication', primaryjoin=tag_id == TagImplication.parent_id, secondaryjoin=tag_id == TagImplication.child_id, lazy='joined') - post_count = column_property( - select([func.count(PostTag.post_id)]) + post_count = sa.orm.column_property( + sa.sql.expression.select( + [sa.sql.expression.func.count(PostTag.post_id)]) .where(PostTag.tag_id == tag_id) .correlate_except(PostTag)) - first_name = column_property( + first_name = sa.orm.column_property( ( - select([TagName.name]) + sa.sql.expression.select([TagName.name]) .where(TagName.tag_id == tag_id) .order_by(TagName.order) .limit(1) @@ -115,17 +117,19 @@ class Tag(Base): ), deferred=True) - suggestion_count = column_property( + suggestion_count = sa.orm.column_property( ( - select([func.count(TagSuggestion.child_id)]) + sa.sql.expression.select( + [sa.sql.expression.func.count(TagSuggestion.child_id)]) .where(TagSuggestion.parent_id == tag_id) .as_scalar() ), deferred=True) - implication_count = column_property( + implication_count = sa.orm.column_property( ( - select([func.count(TagImplication.child_id)]) + sa.sql.expression.select( + [sa.sql.expression.func.count(TagImplication.child_id)]) .where(TagImplication.parent_id == tag_id) .as_scalar() ), diff --git a/server/szurubooru/model/tag_category.py b/server/szurubooru/model/tag_category.py index 001f9653..865cda9d 100644 --- a/server/szurubooru/model/tag_category.py +++ b/server/szurubooru/model/tag_category.py @@ -1,7 +1,5 @@ from typing import Optional -from sqlalchemy import Column, Integer, Unicode, Boolean, table -from sqlalchemy.orm import column_property -from sqlalchemy.sql.expression import func, select +import sqlalchemy as sa from szurubooru.model.base import Base from szurubooru.model.tag import Tag @@ -9,19 +7,20 @@ from szurubooru.model.tag import Tag class TagCategory(Base): __tablename__ = 'tag_category' - tag_category_id = Column('id', Integer, primary_key=True) - version = Column('version', Integer, default=1, nullable=False) - name = Column('name', Unicode(32), nullable=False) - color = Column('color', Unicode(32), nullable=False, default='#000000') - default = Column('default', Boolean, nullable=False, default=False) + tag_category_id = sa.Column('id', sa.Integer, primary_key=True) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + name = sa.Column('name', sa.Unicode(32), nullable=False) + color = sa.Column( + 'color', sa.Unicode(32), nullable=False, default='#000000') + default = sa.Column('default', sa.Boolean, nullable=False, default=False) def __init__(self, name: Optional[str]=None) -> None: self.name = name - tag_count = column_property( - select([func.count('Tag.tag_id')]) + tag_count = sa.orm.column_property( + sa.sql.expression.select([sa.sql.expression.func.count('Tag.tag_id')]) .where(Tag.category_id == tag_category_id) - .correlate_except(table('Tag'))) + .correlate_except(sa.table('Tag'))) __mapper_args__ = { 'version_id_col': version, From ee6b66329b953194e56b2b4d3b7e53b750668dbf Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 23:20:00 +0100 Subject: [PATCH 040/159] server/posts: fix search by aspect ratio It was being rounded to nearest integer because of the width/height columns' data type. --- server/szurubooru/model/post.py | 4 +++- .../tests/search/configs/test_post_search_config.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index d0de4855..a3af2372 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -205,7 +205,9 @@ class Post(Base): .correlate_except(PostTag)) canvas_area = sa.orm.column_property(canvas_width * canvas_height) - canvas_aspect_ratio = sa.orm.column_property(canvas_width / canvas_height) + canvas_aspect_ratio = sa.orm.column_property( + sa.sql.expression.func.cast(canvas_width, sa.Float) / + sa.sql.expression.func.cast(canvas_height, sa.Float)) @property def is_featured(self) -> bool: diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index 551e09c2..ef3cdd9d 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -431,23 +431,27 @@ def test_filter_by_file_size( ('image-area:90000', [3]), ('image-area:20000,90000', [1, 2, 3]), ('image-ar:1', [3]), - ('image-ar:..0.9', [1]), + ('image-ar:..0.9', [1, 4]), ('image-ar:1.1..', [2]), ('image-ar:1/1..1/1', [3]), ('image-ar:1:1..1:1', [3]), + ('image-ar:0.62..0.63', [4]), ]) def test_filter_by_image_size( verify_unpaged, post_factory, input, expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) + post4 = post_factory(id=4) post1.canvas_width = 100 post1.canvas_height = 200 post2.canvas_width = 200 post2.canvas_height = 100 post3.canvas_width = 300 post3.canvas_height = 300 - db.session.add_all([post1, post2, post3]) + post4.canvas_width = 480 + post4.canvas_height = 767 + db.session.add_all([post1, post2, post3, post4]) db.session.flush() verify_unpaged(input, expected_post_ids) From 72056e0cd29635f0ca5afafc874e715ac5866d4f Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 23:27:59 +0100 Subject: [PATCH 041/159] server/requirements: fix skimage package name... Brain fart during previous commit... --- server/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/requirements.txt b/server/requirements.txt index 7b6d53cf..2cd15ec1 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -10,4 +10,4 @@ pycodestyle>=2.0.0 scipy>=0.18.1 elasticsearch>=5.0.0 elasticsearch-dsl>=5.0.0 -skimage>=0.12 +scikit-image>=0.12 From 74c583f11ddb5bd96d6736d59f1b84674b7985af Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 5 Feb 2017 23:29:21 +0100 Subject: [PATCH 042/159] server/build: fix alembic environment script --- server/szurubooru/migrations/env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/szurubooru/migrations/env.py b/server/szurubooru/migrations/env.py index a4257d48..7065a69e 100644 --- a/server/szurubooru/migrations/env.py +++ b/server/szurubooru/migrations/env.py @@ -9,7 +9,7 @@ import logging.config dir_to_self = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2)) -import szurubooru.db.base +import szurubooru.model.base import szurubooru.config alembic_config = alembic.context.config @@ -18,7 +18,7 @@ logging.config.fileConfig(alembic_config.config_file_name) szuru_config = szurubooru.config.config alembic_config.set_main_option('sqlalchemy.url', szuru_config['database']) -target_metadata = szurubooru.db.Base.metadata +target_metadata = szurubooru.model.Base.metadata def run_migrations_offline(): @@ -51,7 +51,7 @@ def run_migrations_online(): connectable = sa.engine_from_config( alembic_config.get_section(alembic_config.config_ini_section), prefix='sqlalchemy.', - poolclass=sqlalchemy.pool.NullPool) + poolclass=sa.pool.NullPool) with connectable.connect() as connection: alembic.context.configure( From 7f09306dde6886082e7bf78aa465b7954b49ef8e Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 7 Feb 2017 18:03:35 +0100 Subject: [PATCH 043/159] server/api: fix unicode urls (#121) --- server/szurubooru/rest/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index b29110e7..ea2a2877 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -33,6 +33,7 @@ def _get_headers(env: Dict[str, Any]) -> Dict[str, str]: def _create_context(env: Dict[str, Any]) -> context.Context: method = env['REQUEST_METHOD'] path = '/' + env['PATH_INFO'].lstrip('/') + path = path.encode('latin-1').decode('utf-8') # PEP-3333 headers = _get_headers(env) files = {} From a3b3532ca4e953610d0ed4235dc63ef6d6c3a631 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Tue, 7 Feb 2017 20:23:47 +0100 Subject: [PATCH 044/159] server/api: patch timing attack on password reset form --- server/szurubooru/api/password_reset_api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index f49080a9..da3effd4 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -1,6 +1,7 @@ from typing import Dict from szurubooru import config, errors, rest from szurubooru.func import auth, mailer, users, versions +from hashlib import md5 MAIL_SUBJECT = 'Password reset for {name}' @@ -30,6 +31,10 @@ def start_password_reset( return {} +def _hash(token: str) -> str: + return md5(token.encode('utf-8')).hexdigest() + + @rest.routes.post('/password-reset/(?P[^/]+)/?') def finish_password_reset( ctx: rest.Context, params: Dict[str, str]) -> rest.Response: @@ -37,7 +42,7 @@ def finish_password_reset( user = users.get_user_by_name_or_email(user_name) good_token = auth.generate_authentication_token(user) token = ctx.get_param_as_string('token') - if token != good_token: + if _hash(token) != _hash(good_token): raise errors.ValidationError('Invalid password reset token.') new_password = users.reset_user_password(user) versions.bump_version(user) From ba7ca0cd8727b5397729983d052636637f36dd59 Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 7 Feb 2017 21:33:44 +0100 Subject: [PATCH 045/159] client/tags: use new color input (#119) --- client/css/core-forms.styl | 20 ++++++++++++--- client/css/tag-categories-view.styl | 5 ++-- client/js/util/views.js | 39 +++++++++++++++++++++-------- 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/client/css/core-forms.styl b/client/css/core-forms.styl index 8323d16e..5b62e0c1 100644 --- a/client/css/core-forms.styl +++ b/client/css/core-forms.styl @@ -170,13 +170,25 @@ input:disabled cursor: not-allowed label.color + white-space: nowrap position: relative + display: flex input[type=text] + margin-right: 0.25em + width: auto + .preview + display: inline-block text-align: center - pointer-events: none - input[type=color] - position: absolute - opacity: 0 + padding: 0 0.5em + border: 2px solid black + &:after + content: 'A' + .background-preview + border-right: 0 + color: transparent + .text-preview + border-left: 0 + form.show-validation .input input:invalid diff --git a/client/css/tag-categories-view.styl b/client/css/tag-categories-view.styl index 31c58380..d5f07627 100644 --- a/client/css/tag-categories-view.styl +++ b/client/css/tag-categories-view.styl @@ -2,7 +2,7 @@ .content-wrapper.tag-categories width: 100% - max-width: 40em + max-width: 45em table border-spacing: 0 width: 100% @@ -11,7 +11,8 @@ td, th padding: .4em &.color - text-align: center + input[type=text] + width: 8em &.usages text-align: center &.remove, &.set-default diff --git a/client/js/util/views.js b/client/js/util/views.js index 7f062599..db5ae4b0 100644 --- a/client/js/util/views.js +++ b/client/js/util/views.js @@ -139,12 +139,29 @@ function makeColorInput(options) { type: 'text', value: options.value || '', required: options.required, - style: 'color: ' + options.value, - disabled: true, + class: 'color', }); - const colorInput = makeElement( - 'input', {type: 'color', value: options.value || ''}); - return makeElement('label', {class: 'color'}, colorInput, textInput); + const backgroundPreviewNode = makeElement( + 'div', + { + class: 'preview background-preview', + style: + `border-color: ${options.value}; + background-color: ${options.value}`, + }); + const textPreviewNode = makeElement( + 'div', + { + class: 'preview text-preview', + style: + `border-color: ${options.value}; + color: ${options.value}`, + }); + return makeElement( + 'label', {class: 'color'}, + textInput, + backgroundPreviewNode, + textPreviewNode); } function makeNumericInput(options) { @@ -478,11 +495,13 @@ function monitorNodeRemoval(monitoredNode, callback) { } document.addEventListener('input', e => { - const type = e.target.getAttribute('type'); - if (type && type.toLowerCase() === 'color') { - const textInput = e.target.parentNode.querySelector('input[type=text]'); - textInput.style.color = e.target.value; - textInput.value = e.target.value; + if (e.target.classList.contains('color')) { + let bkNode = e.target.parentNode.querySelector('.background-preview'); + let textNode = e.target.parentNode.querySelector('.text-preview'); + bkNode.style.backgroundColor = e.target.value; + bkNode.style.borderColor = e.target.value; + textNode.style.color = e.target.value; + textNode.style.borderColor = e.target.value; } }); From fdad08e176b52f1b47930a5e9b22116c70c542a8 Mon Sep 17 00:00:00 2001 From: rr- Date: Thu, 9 Feb 2017 00:48:06 +0100 Subject: [PATCH 046/159] server: use index-based paging (#123) --- API.md | 18 ++-- client/html/comments_page.tpl | 2 +- client/html/manual_pager_nav.tpl | 12 +-- client/html/posts_page.tpl | 4 +- client/html/snapshots_page.tpl | 4 +- client/html/tags_page.tpl | 4 +- client/html/users_page.tpl | 2 +- client/js/controllers/comments_controller.js | 10 +- client/js/controllers/page_controller.js | 6 -- client/js/controllers/post_list_controller.js | 15 +-- client/js/controllers/post_main_controller.js | 2 +- client/js/controllers/snapshots_controller.js | 9 +- client/js/controllers/tag_list_controller.js | 9 +- client/js/controllers/user_list_controller.js | 10 +- client/js/models/post_list.js | 6 +- client/js/models/snapshot_list.js | 4 +- client/js/models/tag_list.js | 6 +- client/js/models/user_list.js | 4 +- client/js/router.js | 4 - client/js/views/comments_page_view.js | 2 +- client/js/views/endless_page_view.js | 91 +++++++++++-------- client/js/views/manual_page_view.js | 81 ++++++++++------- client/js/views/posts_header_view.js | 9 +- client/js/views/posts_page_view.js | 2 +- server/szurubooru/search/executor.py | 21 ++--- .../tests/api/test_comment_retrieving.py | 8 +- .../tests/api/test_post_retrieving.py | 16 ++-- .../tests/api/test_snapshot_retrieving.py | 8 +- .../tests/api/test_tag_retrieving.py | 8 +- .../tests/api/test_user_retrieving.py | 4 +- .../configs/test_comment_search_config.py | 2 +- .../search/configs/test_post_search_config.py | 8 +- .../search/configs/test_tag_search_config.py | 4 +- .../search/configs/test_user_search_config.py | 20 ++-- 34 files changed, 222 insertions(+), 193 deletions(-) diff --git a/API.md b/API.md index 8b25a516..abdc8425 100644 --- a/API.md +++ b/API.md @@ -404,7 +404,7 @@ data. ## Listing tags - **Request** - `GET /tags/?page=&pageSize=&query=` + `GET /tags/?offset=&limit=&query=` - **Output** @@ -675,7 +675,7 @@ data. ## Listing posts - **Request** - `GET /posts/?page=&pageSize=&query=` + `GET /posts/?offset=&limit=&query=` - **Output** @@ -1102,7 +1102,7 @@ data. ## Listing comments - **Request** - `GET /comments/?page=&pageSize=&query=` + `GET /comments/?offset=&limit=&query=` - **Output** @@ -1291,7 +1291,7 @@ data. ## Listing users - **Request** - `GET /users/?page=&pageSize=&query=` + `GET /users/?offset=&limit=&query=` - **Output** @@ -1539,7 +1539,7 @@ data. ## Listing snapshots - **Request** - `GET /snapshots/?page=&pageSize=&query=` + `GET /snapshots/?offset=&limit=&query=` - **Output** @@ -2166,9 +2166,9 @@ A result of search operation that involves paging. ```json5 { - "query": , // same as in input - "page": , // same as in input - "pageSize": , + "query": , // same as in input + "offset": , // same as in input + "limit": , "total": , "results": [ , @@ -2181,7 +2181,7 @@ A result of search operation that involves paging. **Field meaning** - ``: the query passed in the original request that contains standard [search query](#search). -- ``: the page number, passed in the original request. +- ``: the record starting offset, passed in the original request. - ``: number of records on one page. - ``: how many resources were found. To get the page count, divide this number by ``. diff --git a/client/html/comments_page.tpl b/client/html/comments_page.tpl index b117bf96..27d7011d 100644 --- a/client/html/comments_page.tpl +++ b/client/html/comments_page.tpl @@ -1,6 +1,6 @@
    diff --git a/client/html/posts_page.tpl b/client/html/posts_page.tpl index a6c6cc55..895de559 100644 --- a/client/html/posts_page.tpl +++ b/client/html/posts_page.tpl @@ -2,7 +2,7 @@ <% if (ctx.response.results.length) { %>
      <% for (let post of ctx.response.results) { %> -
    • +
    • ' href='<%= ctx.canViewPosts ? ctx.getPostUrl(post.id, ctx.parameters) : '' %>'> @@ -35,9 +35,17 @@ <% if (ctx.canBulkEditTags && ctx.parameters && ctx.parameters.tag) { %> - + <% } %> + <% if (ctx.canBulkEditSafety && ctx.parameters && ctx.parameters.safety) { %> + + <% for (let safety of ['safe', 'sketchy', 'unsafe']) { %> + '> + + <% } %> + + <% } %>
    • <% } %> diff --git a/client/js/controllers/post_list_controller.js b/client/js/controllers/post_list_controller.js index a0973e19..12b354f9 100644 --- a/client/js/controllers/post_list_controller.js +++ b/client/js/controllers/post_list_controller.js @@ -11,7 +11,7 @@ const PostsPageView = require('../views/posts_page_view.js'); const EmptyView = require('../views/empty_view.js'); const fields = [ - 'id', 'thumbnailUrl', 'type', + 'id', 'thumbnailUrl', 'type', 'safety', 'score', 'favoriteCount', 'commentCount', 'tags', 'version']; class PostListController { @@ -32,6 +32,7 @@ class PostListController { hostNode: this._pageController.view.pageHeaderHolderNode, parameters: ctx.parameters, canBulkEditTags: api.hasPrivilege('posts:bulkEdit:tags'), + canBulkEditSafety: api.hasPrivilege('posts:bulkEdit:safety'), bulkEdit: { tags: this._bulkEditTags }, @@ -73,6 +74,11 @@ class PostListController { e.detail.post.save().catch(error => window.alert(error.message)); } + _evtChangeSafety(e) { + e.detail.post.safety = e.detail.safety; + e.detail.post.save().catch(error => window.alert(error.message)); + } + _decorateSearchQuery(text) { const browsingSettings = settings.get(); let disabledSafety = []; @@ -106,6 +112,8 @@ class PostListController { Object.assign(pageCtx, { canViewPosts: api.hasPrivilege('posts:view'), canBulkEditTags: api.hasPrivilege('posts:bulkEdit:tags'), + canBulkEditSafety: + api.hasPrivilege('posts:bulkEdit:safety'), bulkEdit: { tags: this._bulkEditTags, }, @@ -113,6 +121,8 @@ class PostListController { const view = new PostsPageView(pageCtx); view.addEventListener('tag', e => this._evtTag(e)); view.addEventListener('untag', e => this._evtUntag(e)); + view.addEventListener( + 'changeSafety', e => this._evtChangeSafety(e)); return view; }, }); diff --git a/client/js/views/posts_header_view.js b/client/js/views/posts_header_view.js index 0bc1dce9..efa3d02c 100644 --- a/client/js/views/posts_header_view.js +++ b/client/js/views/posts_header_view.js @@ -11,26 +11,19 @@ const TagAutoCompleteControl = const template = views.getTemplate('posts-header'); -class BulkTagEditor extends events.EventTarget { +class BulkEditor extends events.EventTarget { constructor(hostNode) { super(); this._hostNode = hostNode; - - this._autoCompleteControl = new TagAutoCompleteControl( - this._inputNode, {addSpace: false}); this._openLinkNode.addEventListener( 'click', e => this._evtOpenLinkClick(e)); this._closeLinkNode.addEventListener( 'click', e => this._evtCloseLinkClick(e)); - this._hostNode.addEventListener('submit', e => this._evtFormSubmit(e)); - } - - get value() { - return this._inputNode.value; } get opened() { - return this._hostNode.classList.contains('opened'); + return this._hostNode.classList.contains('opened') && + !this._hostNode.classList.contains('hidden'); } get _openLinkNode() { @@ -41,6 +34,53 @@ class BulkTagEditor extends events.EventTarget { return this._hostNode.querySelector('.close'); } + toggleOpen(state) { + this._hostNode.classList.toggle('opened', state); + } + + toggleHide(state) { + this._hostNode.classList.toggle('hidden', state); + } + + _evtOpenLinkClick(e) { + throw new Error('Not implemented'); + } + + _evtCloseLinkClick(e) { + throw new Error('Not implemented'); + } +} + +class BulkSafetyEditor extends BulkEditor { + constructor(hostNode) { + super(hostNode); + } + + _evtOpenLinkClick(e) { + e.preventDefault(); + this.toggleOpen(true); + this.dispatchEvent(new CustomEvent('open', {detail: {}})); + } + + _evtCloseLinkClick(e) { + e.preventDefault(); + this.toggleOpen(false); + this.dispatchEvent(new CustomEvent('close', {detail: {}})); + } +} + +class BulkTagEditor extends BulkEditor { + constructor(hostNode) { + super(hostNode); + this._autoCompleteControl = new TagAutoCompleteControl( + this._inputNode, {addSpace: false}); + this._hostNode.addEventListener('submit', e => this._evtFormSubmit(e)); + } + + get value() { + return this._inputNode.value; + } + get _inputNode() { return this._hostNode.querySelector('input[name=tag]'); } @@ -54,10 +94,6 @@ class BulkTagEditor extends events.EventTarget { this._inputNode.blur(); } - toggleOpen(state) { - this._hostNode.classList.toggle('opened', state); - } - _evtFormSubmit(e) { e.preventDefault(); this.dispatchEvent(new CustomEvent('submit', {detail: {}})); @@ -99,18 +135,38 @@ class PostsHeaderView extends events.EventTarget { safetyButtonNode.addEventListener( 'click', e => this._evtSafetyButtonClick(e)); } - this._formNode.addEventListener( - 'submit', e => this._evtFormSubmit(e)); + this._formNode.addEventListener('submit', e => this._evtFormSubmit(e)); + this._bulkEditors = []; if (this._bulkEditTagsNode) { this._bulkTagEditor = new BulkTagEditor(this._bulkEditTagsNode); - this._bulkTagEditor.toggleOpen(!!ctx.parameters.tag); + this._bulkEditors.push(this._bulkTagEditor); + } + + if (this._bulkEditSafetyNode) { + this._bulkSafetyEditor = new BulkSafetyEditor( + this._bulkEditSafetyNode); + this._bulkEditors.push(this._bulkSafetyEditor); + } + + for (let editor of this._bulkEditors) { this._bulkTagEditor.addEventListener('submit', e => { this._navigate(); }); - this._bulkTagEditor.addEventListener('close', e => { + editor.addEventListener('open', e => { + this._hideBulkEditorsExcept(editor); this._navigate(); }); + editor.addEventListener('close', e => { + this._closeAndShowAllBulkEditors(); + this._navigate(); + }); + } + + if (ctx.parameters.tag && this._bulkTagEditor) { + this._openBulkEditor(this._bulkTagEditor); + } else if (ctx.parameters.safety && this._bulkSafetyEditor) { + this._openBulkEditor(this._bulkSafetyEditor); } } @@ -130,6 +186,31 @@ class PostsHeaderView extends events.EventTarget { return this._hostNode.querySelector('.bulk-edit-tags'); } + get _bulkEditSafetyNode() { + return this._hostNode.querySelector('.bulk-edit-safety'); + } + + _openBulkEditor(editor) { + editor.toggleOpen(true); + this._hideBulkEditorsExcept(editor); + } + + _hideBulkEditorsExcept(editor) { + for (let otherEditor of this._bulkEditors) { + if (otherEditor !== editor) { + otherEditor.toggleOpen(false); + otherEditor.toggleHide(true); + } + } + } + + _closeAndShowAllBulkEditors() { + for (let otherEditor of this._bulkEditors) { + otherEditor.toggleOpen(false); + otherEditor.toggleHide(false); + } + } + _evtSafetyButtonClick(e, url) { e.preventDefault(); e.target.classList.toggle('disabled'); @@ -164,6 +245,9 @@ class PostsHeaderView extends events.EventTarget { } else { parameters.tag = null; } + parameters.safety = ( + this._bulkSafetyEditor && + this._bulkSafetyEditor.opened ? '1' : null); this.dispatchEvent( new CustomEvent('navigate', {detail: {parameters: parameters}})); } diff --git a/client/js/views/posts_page_view.js b/client/js/views/posts_page_view.js index c2dbe904..26103251 100644 --- a/client/js/views/posts_page_view.js +++ b/client/js/views/posts_page_view.js @@ -18,26 +18,48 @@ class PostsPageView extends events.EventTarget { post.addEventListener('change', e => this._evtPostChange(e)); } - this._postIdToLinkNode = {}; - for (let linkNode of this._tagFlipperNodes) { - const postId = linkNode.getAttribute('data-post-id'); + this._postIdToListItemNode = {}; + for (let listItemNode of this._listItemNodes) { + const postId = listItemNode.getAttribute('data-post-id'); const post = this._postIdToPost[postId]; - this._postIdToLinkNode[postId] = linkNode; - linkNode.addEventListener( - 'click', e => this._evtBulkEditTagsClick(e, post)); + this._postIdToListItemNode[postId] = listItemNode; + + const tagFlipperNode = this._getTagFlipperNode(listItemNode); + if (tagFlipperNode) { + tagFlipperNode.addEventListener( + 'click', e => this._evtBulkEditTagsClick(e, post)); + } + + const safetyFlipperNode = this._getSafetyFlipperNode(listItemNode); + if (safetyFlipperNode) { + for (let linkNode of safetyFlipperNode.querySelectorAll('a')) { + linkNode.addEventListener( + 'click', e => this._evtBulkEditSafetyClick(e, post)); + } + } } - this._syncTagFlippersHighlights(); + this._syncBulkEditorsHighlights(); } - get _tagFlipperNodes() { - return this._hostNode.querySelectorAll('.tag-flipper'); + get _listItemNodes() { + return this._hostNode.querySelectorAll('li'); + } + + _getTagFlipperNode(listItemNode) { + return listItemNode.querySelector('.tag-flipper'); + } + + _getSafetyFlipperNode(listItemNode) { + return listItemNode.querySelector('.safety-flipper'); } _evtPostChange(e) { - const linkNode = this._postIdToLinkNode[e.detail.post.id]; - linkNode.removeAttribute('data-disabled'); - this._syncTagFlippersHighlights(); + const listItemNode = this._postIdToListItemNode[e.detail.post.id]; + for (let node of listItemNode.querySelectorAll('[data-disabled]')) { + node.removeAttribute('data-disabled'); + } + this._syncBulkEditorsHighlights(); } _evtBulkEditTagsClick(e, post) { @@ -53,15 +75,43 @@ class PostsPageView extends events.EventTarget { {detail: {post: post}})); } - _syncTagFlippersHighlights() { - for (let linkNode of this._tagFlipperNodes) { - const postId = linkNode.getAttribute('data-post-id'); + _evtBulkEditSafetyClick(e, post) { + e.preventDefault(); + const linkNode = e.target; + if (linkNode.getAttribute('data-disabled')) { + return; + } + const newSafety = linkNode.getAttribute('data-safety'); + if (post.safety === newSafety) { + return; + } + linkNode.setAttribute('data-disabled', true); + this.dispatchEvent( + new CustomEvent( + 'changeSafety', {detail: {post: post, safety: newSafety}})); + } + + _syncBulkEditorsHighlights() { + for (let listItemNode of this._listItemNodes) { + const postId = listItemNode.getAttribute('data-post-id'); const post = this._postIdToPost[postId]; - let tagged = true; - for (let tag of this._ctx.bulkEdit.tags) { - tagged = tagged & post.isTaggedWith(tag); + + const tagFlipperNode = this._getTagFlipperNode(listItemNode); + if (tagFlipperNode) { + let tagged = true; + for (let tag of this._ctx.bulkEdit.tags) { + tagged = tagged & post.isTaggedWith(tag); + } + tagFlipperNode.classList.toggle('tagged', tagged); + } + + const safetyFlipperNode = this._getSafetyFlipperNode(listItemNode); + if (safetyFlipperNode) { + for (let linkNode of safetyFlipperNode.querySelectorAll('a')) { + const safety = linkNode.getAttribute('data-safety'); + linkNode.classList.toggle('active', post.safety == safety); + } } - linkNode.classList.toggle('tagged', tagged); } } } diff --git a/config.yaml.dist b/config.yaml.dist index 2b03c0c8..9d50e50c 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -90,6 +90,7 @@ privileges: 'posts:merge': moderator 'posts:favorite': regular 'posts:bulk-edit:tags': power + 'posts:bulk-edit:safety': power 'tags:create': regular 'tags:edit:names': power From aa1f4d3ff8061db07111ab006a82edeac2ef7b18 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 12 Feb 2017 10:40:50 +0100 Subject: [PATCH 051/159] client/posts: add file extensions info to upload --- client/css/post-upload.styl | 2 ++ client/html/file_dropper.tpl | 4 ++++ client/js/controls/file_dropper_control.js | 7 ++++--- client/js/views/post_upload_view.js | 2 ++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/client/css/post-upload.styl b/client/css/post-upload.styl index 4da53a77..147df6bc 100644 --- a/client/css/post-upload.styl +++ b/client/css/post-upload.styl @@ -22,6 +22,8 @@ $cancel-button-color = tomato .file-dropper font-size: 150% padding: 2em + small + font-size: 60% input[type=submit] margin-top: 1em diff --git a/client/html/file_dropper.tpl b/client/html/file_dropper.tpl index 9c662010..9e7715d5 100644 --- a/client/html/file_dropper.tpl +++ b/client/html/file_dropper.tpl @@ -8,6 +8,10 @@ <% } %>
      Or just click on this box. + <% if (ctx.extraText) { %> +
      + <%= ctx.extraText %> + <% } %> <% if (ctx.allowUrls) { %> diff --git a/client/js/controls/file_dropper_control.js b/client/js/controls/file_dropper_control.js index 113e7aca..7b3c1209 100644 --- a/client/js/controls/file_dropper_control.js +++ b/client/js/controls/file_dropper_control.js @@ -11,8 +11,9 @@ class FileDropperControl extends events.EventTarget { this._options = options; const source = template({ - allowMultiple: this._options.allowMultiple, - allowUrls: this._options.allowUrls, + extraText: options.extraText, + allowMultiple: options.allowMultiple, + allowUrls: options.allowUrls, id: 'file-' + Math.random().toString(36).substring(7), }); @@ -21,7 +22,7 @@ class FileDropperControl extends events.EventTarget { this._urlConfirmButtonNode = source.querySelector('button'); this._fileInputNode = source.querySelector('input[type=file]'); this._fileInputNode.style.display = 'none'; - this._fileInputNode.multiple = this._options.allowMultiple || false; + this._fileInputNode.multiple = options.allowMultiple || false; this._counter = 0; this._dropperNode.addEventListener( diff --git a/client/js/views/post_upload_view.js b/client/js/views/post_upload_view.js index 144526be..f80135f4 100644 --- a/client/js/views/post_upload_view.js +++ b/client/js/views/post_upload_view.js @@ -156,6 +156,8 @@ class PostUploadView extends events.EventTarget { this._contentFileDropper = new FileDropperControl( this._contentInputNode, { + extraText: + 'Allowed extensions: .jpg, .png, .gif, .webm, .mp4, .swf', allowUrls: true, allowMultiple: true, lock: false, From 32d15a493c713dcfbe56bafb4e8ce92a5ee40df1 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 12 Feb 2017 10:41:49 +0100 Subject: [PATCH 052/159] client/css: add margin to file dropper button --- client/css/core-forms.styl | 1 + 1 file changed, 1 insertion(+) diff --git a/client/css/core-forms.styl b/client/css/core-forms.styl index 541de1bc..06a54421 100644 --- a/client/css/core-forms.styl +++ b/client/css/core-forms.styl @@ -259,6 +259,7 @@ input::-moz-focus-inner word-wrap: break-word input margin-top: 0.5em + margin-right: 0.5em width: auto flex: 1 button From c01214e919d0872d7b7644b2ec9d7472d8c6bbc9 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 17 Feb 2017 23:08:37 +0100 Subject: [PATCH 053/159] server/password-reset: support having no smtp --- client/css/password-reset.styl | 2 ++ client/html/login.tpl | 4 +-- client/html/password_reset.tpl | 43 +++++++++++++++----------- client/js/views/password_reset_view.js | 6 +++- config.yaml.dist | 14 +++++++-- 5 files changed, 45 insertions(+), 24 deletions(-) create mode 100644 client/css/password-reset.styl diff --git a/client/css/password-reset.styl b/client/css/password-reset.styl new file mode 100644 index 00000000..47e32f3f --- /dev/null +++ b/client/css/password-reset.styl @@ -0,0 +1,2 @@ +#password-reset + max-width: 30em diff --git a/client/html/login.tpl b/client/html/login.tpl index 8ccc439d..186a5489 100644 --- a/client/html/login.tpl +++ b/client/html/login.tpl @@ -30,9 +30,7 @@
      - <% if (ctx.canSendMails) { %> - '>Forgot the password? - <% } %> + '>Forgot the password?
    diff --git a/client/html/password_reset.tpl b/client/html/password_reset.tpl index b50b48ee..5671379e 100644 --- a/client/html/password_reset.tpl +++ b/client/html/password_reset.tpl @@ -1,23 +1,30 @@

    Password reset

    -
    -
      -
    • - <%= ctx.makeTextInput({ - text: 'User name or e-mail address', - name: 'user-name', - required: true, - }) %> -
    • -
    + <% if (ctx.canSendMails) { %> + +
      +
    • + <%= ctx.makeTextInput({ + text: 'User name or e-mail address', + name: 'user-name', + required: true, + }) %> +
    • +
    -

    Proceeding will send an e-mail that contains a password reset - link. Clicking it is going to generate a new password for your account. - It is recommended to change that password to something else.

    +

    Proceeding will send an e-mail that contains a password reset + link. Clicking it is going to generate a new password for your account. + It is recommended to change that password to something else.

    -
    -
    - -
    -
    +
    +
    + +
    + + <% } else { %> +

    We do not support automatic password resetting.

    + <% if (ctx.contactEmail) { %> +

    Please send an e-mail to <%- ctx.contactEmail %> to go through a manual procedure.

    + <% } %> + <% } %>
    diff --git a/client/js/views/password_reset_view.js b/client/js/views/password_reset_view.js index 38ac719b..25409c77 100644 --- a/client/js/views/password_reset_view.js +++ b/client/js/views/password_reset_view.js @@ -1,5 +1,6 @@ 'use strict'; +const config = require('../config.js'); const events = require('../events.js'); const views = require('../util/views.js'); @@ -10,7 +11,10 @@ class PasswordResetView extends events.EventTarget { super(); this._hostNode = document.getElementById('content-holder'); - views.replaceContent(this._hostNode, template()); + views.replaceContent(this._hostNode, template({ + canSendMails: config.canSendMails, + contactEmail: config.contactEmail, + })); views.syncScrollPosition(); views.decorateValidator(this._formNode); diff --git a/config.yaml.dist b/config.yaml.dist index 9d50e50c..3b00d0eb 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -25,12 +25,18 @@ thumbnails: post_width: 300 post_height: 300 -# used to send password reminders + +# used to send password reset e-mails smtp: host: # example: localhost port: # example: 25 user: # example: bot pass: # example: groovy123 + # host can be left empty, in which case it is recommended to fill contactEmail. + + +contactEmail: # example: bob@example.com. Meant for manual password reset procedures + # used for reverse image search elasticsearch: @@ -38,15 +44,16 @@ elasticsearch: port: 9200 index: szurubooru + limits: users_per_page: 20 posts_per_page: 40 max_comment_length: 5000 + tag_name_regex: ^\S+$ tag_category_name_regex: ^[^\s%+#/]+$ -default_rank: regular # don't change these, unless you want to annoy people. if you do customize # them though, make sure to update the instructions in the registration form @@ -54,6 +61,9 @@ default_rank: regular password_regex: '^.{5,}$' user_name_regex: '^[a-zA-Z0-9_-]{1,32}$' +default_rank: regular + + privileges: 'users:create': anonymous 'users:list': regular From 33b49ebffdb09b5cdfd4ea091e8dba76672fc56f Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 19 Feb 2017 13:50:28 +0100 Subject: [PATCH 054/159] client/paging: fix mass tag double binding Fixes #125 --- client/js/views/posts_header_view.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/js/views/posts_header_view.js b/client/js/views/posts_header_view.js index efa3d02c..3f70ef5a 100644 --- a/client/js/views/posts_header_view.js +++ b/client/js/views/posts_header_view.js @@ -150,7 +150,7 @@ class PostsHeaderView extends events.EventTarget { } for (let editor of this._bulkEditors) { - this._bulkTagEditor.addEventListener('submit', e => { + editor.addEventListener('submit', e => { this._navigate(); }); editor.addEventListener('open', e => { From 5dfdfd49e9f812844560c4185abd568d14df86c8 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 19 Feb 2017 13:33:13 +0100 Subject: [PATCH 055/159] client/paging: fix loading on small page sizes Fixes #126 --- client/html/endless_pager.tpl | 2 + client/js/views/endless_page_view.js | 90 ++++++++++++++++++---------- 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/client/html/endless_pager.tpl b/client/html/endless_pager.tpl index 6870f9a5..4812f96d 100644 --- a/client/html/endless_pager.tpl +++ b/client/html/endless_pager.tpl @@ -1,5 +1,7 @@
    +
    +
    diff --git a/client/js/views/endless_page_view.js b/client/js/views/endless_page_view.js index 9ed82fe4..93fb5200 100644 --- a/client/js/views/endless_page_view.js +++ b/client/js/views/endless_page_view.js @@ -6,6 +6,13 @@ const views = require('../util/views.js'); const holderTemplate = views.getTemplate('endless-pager'); const pageTemplate = views.getTemplate('endless-pager-page'); +function isScrolledIntoView(el) { + const elemTop = el.getBoundingClientRect().top; + const elemBottom = el.getBoundingClientRect().bottom; + const isVisible = (elemTop >= 0) && (elemBottom <= window.innerHeight); + return isVisible; +} + class EndlessPageView { constructor(ctx) { this._hostNode = document.getElementById('content-holder'); @@ -13,28 +20,35 @@ class EndlessPageView { } run(ctx) { + this._destroy(); + this._active = true; - this._working = 0; - this._init = false; + this._runningRequests = 0; + this._initialPageLoad = true; this.clearMessages(); views.emptyContent(this._pagesHolderNode); - this.threshold = window.innerHeight / 3; this.minOffsetShown = null; this.maxOffsetShown = null; this.totalRecords = null; this.currentOffset = 0; + this.defaultLimit = parseInt(ctx.parameters.limit || ctx.defaultLimit); - const offset = parseInt(ctx.parameters.offset || 0); - const limit = parseInt(ctx.parameters.limit || ctx.defaultLimit); - this._loadPage(ctx, offset, limit, true) + const initialOffset = parseInt(ctx.parameters.offset || 0); + this._loadPage(ctx, initialOffset, this.defaultLimit, true) .then(pageNode => { - if (offset !== 0) { + if (initialOffset !== 0) { pageNode.scrollIntoView(); } }); - this._probePageLoad(ctx); + + this._timeout = window.setInterval(() => { + window.requestAnimationFrame(() => { + this._probePageLoad(ctx); + this._syncUrl(ctx); + }); + }, 250); views.monitorNodeRemoval(this._pagesHolderNode, () => this._destroy()); } @@ -43,27 +57,24 @@ class EndlessPageView { return this._hostNode.querySelector('.page-header-holder'); } + get topPageGuardNode() { + return this._hostNode.querySelector('.page-guard.top'); + } + + get bottomPageGuardNode() { + return this._hostNode.querySelector('.page-guard.bottom'); + } + get _pagesHolderNode() { return this._hostNode.querySelector('.pages-holder'); } _destroy() { + window.clearInterval(this._timeout); this._active = false; } - _probePageLoad(ctx) { - if (this._active) { - window.setTimeout(() => { - window.requestAnimationFrame(() => { - this._probePageLoad(ctx); - }); - }, 250); - } - - if (this._working) { - return; - } - + _syncUrl(ctx) { let topPageNode = null; let element = document.elementFromPoint( window.innerWidth / 2, @@ -89,6 +100,12 @@ class EndlessPageView { false); this.currentOffset = topOffset; } + } + + _probePageLoad(ctx) { + if (!this._active || this._runningRequests) { + return; + } if (this.totalRecords === null) { return; @@ -97,32 +114,41 @@ class EndlessPageView { document.documentElement.scrollHeight - document.documentElement.clientHeight; - if (this.minOffsetShown > 0 && window.scrollY < this.threshold) { + if (this.minOffsetShown > 0 && + isScrolledIntoView(this.topPageGuardNode)) { this._loadPage( - ctx, this.minOffsetShown - topLimit, topLimit, false); - } else if (this.maxOffsetShown < this.totalRecords && - window.scrollY + this.threshold > scrollHeight) { + ctx, + this.minOffsetShown - this.defaultLimit, + this.defaultLimit, + false); + } + + if (this.maxOffsetShown < this.totalRecords && + isScrolledIntoView(this.bottomPageGuardNode)) { this._loadPage( - ctx, this.maxOffsetShown, topLimit, true); + ctx, + this.maxOffsetShown, + this.defaultLimit, + true); } } _loadPage(ctx, offset, limit, append) { - this._working++; + this._runningRequests++; return new Promise((resolve, reject) => { ctx.requestPage(offset, limit).then(response => { if (!this._active) { - this._working--; + this._runningRequests--; return Promise.reject(); } window.requestAnimationFrame(() => { let pageNode = this._renderPage(ctx, append, response); - this._working--; + this._runningRequests--; resolve(pageNode); }); }, error => { this.showError(error.message); - this._working--; + this._runningRequests--; reject(); }); }); @@ -165,7 +191,7 @@ class EndlessPageView { if (append) { this._pagesHolderNode.appendChild(pageNode); - if (!this._init && response.offset > 0) { + if (this._initialPageLoad && response.offset > 0) { window.scroll(0, pageNode.getBoundingClientRect().top); } } else { @@ -179,7 +205,7 @@ class EndlessPageView { this.showInfo('No data to show'); } - this._init = true; + this._initialPageLoad = false; return pageNode; } From 34366b72fbc8583c8eb8478dde0d65c4b9321b18 Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 21 Feb 2017 18:29:32 +0100 Subject: [PATCH 056/159] client/file-dropper: add ability to lock URLs --- client/html/file_dropper.tpl | 6 +++++- client/js/controls/file_dropper_control.js | 25 ++++++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/client/html/file_dropper.tpl b/client/html/file_dropper.tpl index 9e7715d5..2ce0ff51 100644 --- a/client/html/file_dropper.tpl +++ b/client/html/file_dropper.tpl @@ -15,6 +15,10 @@ <% if (ctx.allowUrls) { %> - + <% if (ctx.lock) { %> + + <% } else { %> + + <% } %> <% } %>
    diff --git a/client/js/controls/file_dropper_control.js b/client/js/controls/file_dropper_control.js index 7b3c1209..407fa9d3 100644 --- a/client/js/controls/file_dropper_control.js +++ b/client/js/controls/file_dropper_control.js @@ -5,6 +5,8 @@ const views = require('../util/views.js'); const template = views.getTemplate('file-dropper'); +const KEY_RETURN = 13; + class FileDropperControl extends events.EventTarget { constructor(target, options) { super(); @@ -14,6 +16,7 @@ class FileDropperControl extends events.EventTarget { extraText: options.extraText, allowMultiple: options.allowMultiple, allowUrls: options.allowUrls, + lock: options.lock, id: 'file-' + Math.random().toString(36).substring(7), }); @@ -37,8 +40,12 @@ class FileDropperControl extends events.EventTarget { 'change', e => this._evtFileChange(e)); if (this._urlInputNode) { + this._urlInputNode.addEventListener( + 'keydown', e => this._evtUrlInputKeyDown(e)); + } + if (this._urlConfirmButtonNode) { this._urlConfirmButtonNode.addEventListener( - 'click', e => this._evtUrlConfirm(e)); + 'click', e => this._evtUrlConfirmButtonClick(e)); } this._originalHtml = this._dropperNode.innerHTML; @@ -62,6 +69,10 @@ class FileDropperControl extends events.EventTarget { _emitUrls(urls) { urls = Array.from(urls).map(url => url.trim()); + if (this._options.lock) { + this._dropperNode.innerText = + urls.map(url => url.split(/\//).reverse()[0]).join(', '); + } for (let url of urls) { if (!url) { return; @@ -106,7 +117,17 @@ class FileDropperControl extends events.EventTarget { this._emitFiles(e.dataTransfer.files); } - _evtUrlConfirm(e) { + _evtUrlInputKeyDown(e) { + if (e.which !== KEY_RETURN) { + return; + } + e.preventDefault(); + this._dropperNode.classList.remove('active'); + this._emitUrls(this._urlInputNode.value.split(/[\r\n]/)); + this._urlInputNode.value = ''; + } + + _evtUrlConfirmButtonClick(e) { e.preventDefault(); this._dropperNode.classList.remove('active'); this._emitUrls(this._urlInputNode.value.split(/[\r\n]/)); From b27855523a5dbd0a59db5eb2e60c0791ec9d4ebe Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 21 Feb 2017 18:29:51 +0100 Subject: [PATCH 057/159] client/file-dropper: fix drawing long URLs --- client/css/core-forms.styl | 1 + 1 file changed, 1 insertion(+) diff --git a/client/css/core-forms.styl b/client/css/core-forms.styl index 06a54421..58002263 100644 --- a/client/css/core-forms.styl +++ b/client/css/core-forms.styl @@ -256,6 +256,7 @@ input::-moz-focus-inner line-height: 140% text-align: center cursor: pointer + overflow: hidden word-wrap: break-word input margin-top: 0.5em From 1e58899b03319c56682dbc562cbf5828b97bddf3 Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 21 Feb 2017 18:30:19 +0100 Subject: [PATCH 058/159] client/posts: allow updating content from URL --- client/js/controls/post_edit_sidebar_control.js | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client/js/controls/post_edit_sidebar_control.js b/client/js/controls/post_edit_sidebar_control.js index 70570d75..f1cfe934 100644 --- a/client/js/controls/post_edit_sidebar_control.js +++ b/client/js/controls/post_edit_sidebar_control.js @@ -72,10 +72,13 @@ class PostEditSidebarControl extends events.EventTarget { if (this._contentInputNode) { this._contentFileDropper = new FileDropperControl( - this._contentInputNode, {lock: true}); + this._contentInputNode, {allowUrls: true, lock: true}); this._contentFileDropper.addEventListener('fileadd', e => { this._newPostContent = e.detail.files[0]; }); + this._contentFileDropper.addEventListener('urladd', e => { + this._newPostContent = e.detail.urls[0]; + }); } if (this._thumbnailInputNode) { From d00d282bff15d37cff7818d39663ad07b3cbee43 Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 21 Feb 2017 18:54:46 +0100 Subject: [PATCH 059/159] client/posts: improve file dropper appearance --- client/css/core-forms.styl | 19 +++++++++---------- client/css/post-main-view.styl | 2 +- client/html/file_dropper.tpl | 14 ++++++++------ 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/client/css/core-forms.styl b/client/css/core-forms.styl index 58002263..a163718b 100644 --- a/client/css/core-forms.styl +++ b/client/css/core-forms.styl @@ -245,11 +245,8 @@ input::-moz-focus-inner * File dropper */ .file-dropper-holder - display: flex - flex-wrap: wrap .file-dropper display: block - width: 100% background: $window-color border: 3px dashed #eee padding: 0.3em 0.5em @@ -258,14 +255,16 @@ input::-moz-focus-inner cursor: pointer overflow: hidden word-wrap: break-word - input + .url-holder + display: flex margin-top: 0.5em - margin-right: 0.5em - width: auto - flex: 1 - button - margin-top: 0.5em - width: 8em + input, button + min-width: 0 /* firefox being sassy */ + width: auto !important /* don't inherit anything weird */ + input + flex: 1 + button + margin-left: 0.5em input[type=file]:disabled+.file-dropper cursor: default diff --git a/client/css/post-main-view.styl b/client/css/post-main-view.styl index b4d9b8fe..cfebf591 100644 --- a/client/css/post-main-view.styl +++ b/client/css/post-main-view.styl @@ -138,7 +138,7 @@ margin: 0 padding: 0 - label + label:not(.file-dropper) margin-bottom: 0.3em display: block diff --git a/client/html/file_dropper.tpl b/client/html/file_dropper.tpl index 2ce0ff51..3da2f4f8 100644 --- a/client/html/file_dropper.tpl +++ b/client/html/file_dropper.tpl @@ -14,11 +14,13 @@ <% } %> <% if (ctx.allowUrls) { %> - - <% if (ctx.lock) { %> - - <% } else { %> - - <% } %> +
    + + <% if (ctx.lock) { %> + + <% } else { %> + + <% } %> +
    <% } %> From 5467ca6b7e259f489927c51ad5ab00ef4835e1c2 Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 21 Feb 2017 19:09:18 +0100 Subject: [PATCH 060/159] client/posts: improve placeholder in file dropper The default one was too long to fit in the sidebar --- client/html/file_dropper.tpl | 2 +- client/js/controls/file_dropper_control.js | 2 ++ client/js/controls/post_edit_sidebar_control.js | 5 ++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/client/html/file_dropper.tpl b/client/html/file_dropper.tpl index 3da2f4f8..3123cd01 100644 --- a/client/html/file_dropper.tpl +++ b/client/html/file_dropper.tpl @@ -15,7 +15,7 @@ <% if (ctx.allowUrls) { %>
    - + <% if (ctx.lock) { %> <% } else { %> diff --git a/client/js/controls/file_dropper_control.js b/client/js/controls/file_dropper_control.js index 407fa9d3..f725ee01 100644 --- a/client/js/controls/file_dropper_control.js +++ b/client/js/controls/file_dropper_control.js @@ -18,6 +18,8 @@ class FileDropperControl extends events.EventTarget { allowUrls: options.allowUrls, lock: options.lock, id: 'file-' + Math.random().toString(36).substring(7), + urlPlaceholder: + options.urlPlaceholder || 'Alternatively, paste an URL here.', }); this._dropperNode = source.querySelector('.file-dropper'); diff --git a/client/js/controls/post_edit_sidebar_control.js b/client/js/controls/post_edit_sidebar_control.js index f1cfe934..3da2fa41 100644 --- a/client/js/controls/post_edit_sidebar_control.js +++ b/client/js/controls/post_edit_sidebar_control.js @@ -72,7 +72,10 @@ class PostEditSidebarControl extends events.EventTarget { if (this._contentInputNode) { this._contentFileDropper = new FileDropperControl( - this._contentInputNode, {allowUrls: true, lock: true}); + this._contentInputNode, { + allowUrls: true, + lock: true, + urlPlaceholder: '...or paste an URL here.'}); this._contentFileDropper.addEventListener('fileadd', e => { this._newPostContent = e.detail.files[0]; }); From 87b3572ce594636ea57302cbe4f93c2d50ff39c0 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 26 Feb 2017 12:57:24 +0100 Subject: [PATCH 061/159] client/paging: fix endless scroll on android --- client/js/views/endless_page_view.js | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/client/js/views/endless_page_view.js b/client/js/views/endless_page_view.js index 93fb5200..ab406623 100644 --- a/client/js/views/endless_page_view.js +++ b/client/js/views/endless_page_view.js @@ -6,11 +6,15 @@ const views = require('../util/views.js'); const holderTemplate = views.getTemplate('endless-pager'); const pageTemplate = views.getTemplate('endless-pager-page'); -function isScrolledIntoView(el) { - const elemTop = el.getBoundingClientRect().top; - const elemBottom = el.getBoundingClientRect().bottom; - const isVisible = (elemTop >= 0) && (elemBottom <= window.innerHeight); - return isVisible; +function isScrolledIntoView(element) { + let top = 0; + do { + top += element.offsetTop || 0; + element = element.offsetParent; + } while(element); + return ( + (top >= window.scrollY) && + (top <= window.scrollY + window.innerHeight)); } class EndlessPageView { @@ -110,9 +114,6 @@ class EndlessPageView { if (this.totalRecords === null) { return; } - let scrollHeight = - document.documentElement.scrollHeight - - document.documentElement.clientHeight; if (this.minOffsetShown > 0 && isScrolledIntoView(this.topPageGuardNode)) { From e087b83082703c54cad919c961a95cadc8e19cfa Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 26 Feb 2017 18:47:53 +0100 Subject: [PATCH 062/159] client/notes: don't rely on class names The state names, used by CSS, were being broken by the minifier. --- .../js/controls/post_notes_overlay_control.js | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/client/js/controls/post_notes_overlay_control.js b/client/js/controls/post_notes_overlay_control.js index 43e8bf69..bbea0956 100644 --- a/client/js/controls/post_notes_overlay_control.js +++ b/client/js/controls/post_notes_overlay_control.js @@ -72,10 +72,8 @@ function _getNoteSize(note) { } class State { - constructor(control) { + constructor(control, stateName) { this._control = control; - const stateName = misc.decamelize( - this.constructor.name.replace(/State/, '')); _setNodeState(control._hostNode, stateName); _setNodeState(control._textNode, stateName); } @@ -132,7 +130,7 @@ class State { class ReadOnlyState extends State { constructor(control) { - super(control); + super(control, 'read-only'); if (_clearEditedNote(control._hostNode)) { this._control.dispatchEvent(new CustomEvent('blur')); } @@ -146,7 +144,7 @@ class ReadOnlyState extends State { class PassiveState extends State { constructor(control) { - super(control); + super(control, 'passive'); if (_clearEditedNote(control._hostNode)) { this._control.dispatchEvent(new CustomEvent('blur')); } @@ -163,13 +161,13 @@ class PassiveState extends State { } class ActiveState extends State { - constructor(control, note) { - super(control); + constructor(control, note, stateName) { + super(control, stateName); if (_clearEditedNote(control._hostNode)) { this._control.dispatchEvent(new CustomEvent('blur')); } keyboard.pause(); - if (note !== undefined) { + if (note !== null) { this._note = note; this._control.dispatchEvent( new CustomEvent('focus', { @@ -182,7 +180,7 @@ class ActiveState extends State { class SelectedState extends ActiveState { constructor(control, note) { - super(control, note); + super(control, note, 'selected'); this._clickTimeout = null; this._control._hideNoteText(); } @@ -299,7 +297,7 @@ class SelectedState extends ActiveState { class MovingPointState extends ActiveState { constructor(control, note, notePoint, mousePoint) { - super(control, note); + super(control, note, 'moving-point'); this._notePoint = notePoint; this._originalNotePoint = {x: notePoint.x, y: notePoint.y}; this._originalPosition = mousePoint; @@ -328,7 +326,7 @@ class MovingPointState extends ActiveState { class MovingNoteState extends ActiveState { constructor(control, note, mousePoint) { - super(control, note); + super(control, note, 'moving-note'); this._originalPolygon = [...note.polygon].map( point => ({x: point.x, y: point.y})); this._originalPosition = mousePoint; @@ -360,7 +358,7 @@ class MovingNoteState extends ActiveState { class ScalingNoteState extends ActiveState { constructor(control, note, mousePoint) { - super(control, note); + super(control, note, 'scaling-note'); this._originalPolygon = [...note.polygon].map( point => ({x: point.x, y: point.y})); this._originalMousePoint = mousePoint; @@ -402,7 +400,7 @@ class ScalingNoteState extends ActiveState { class ReadyToDrawState extends ActiveState { constructor(control) { - super(control); + super(control, null, 'ready-to-draw'); } evtNoteMouseDown(e, hoveredNote) { @@ -423,7 +421,7 @@ class ReadyToDrawState extends ActiveState { class DrawingRectangleState extends ActiveState { constructor(control, mousePoint) { - super(control); + super(control, null, 'drawing-rectangle'); this._note = this._createNote(); this._note.polygon.add(new Point(mousePoint.x, mousePoint.y)); this._note.polygon.add(new Point(mousePoint.x, mousePoint.y)); @@ -460,7 +458,7 @@ class DrawingRectangleState extends ActiveState { class DrawingPolygonState extends ActiveState { constructor(control, mousePoint) { - super(control); + super(control, null, 'drawing-polygon'); this._note = this._createNote(); this._note.polygon.add(new Point(mousePoint.x, mousePoint.y)); this._note.polygon.add(new Point(mousePoint.x, mousePoint.y)); From 5681fd11efe93b5b2a1719f16c9f253ffd508842 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 3 Mar 2017 17:24:58 +0100 Subject: [PATCH 063/159] server/net: make the user-agent configurable Fixes #127 --- config.yaml.dist | 1 + server/szurubooru/func/net.py | 3 +++ server/szurubooru/tests/func/test_net.py | 5 ++++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/config.yaml.dist b/config.yaml.dist index 3b00d0eb..d1dc97c5 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -10,6 +10,7 @@ api_url: # where frontend connects to, example: http://api.example.com/ base_url: # used to form links to frontend, example: http://example.com/ data_url: # used to form links to posts and avatars, example: http://example.com/data/ data_dir: # absolute path for posts and avatars storage, example: /srv/www/booru/client/public/data/ +user_agent: # user agent name used to download files from the web on behalf of the api users # usage: schema://user:password@host:port/database_name diff --git a/server/szurubooru/func/net.py b/server/szurubooru/func/net.py index a6e18214..e6326c06 100644 --- a/server/szurubooru/func/net.py +++ b/server/szurubooru/func/net.py @@ -1,10 +1,13 @@ import urllib.request +from szurubooru import config from szurubooru import errors def download(url: str) -> bytes: assert url request = urllib.request.Request(url) + if config.config['user_agent']: + request.add_header('User-Agent', config.config['user_agent']) request.add_header('Referer', url) try: with urllib.request.urlopen(request) as handle: diff --git a/server/szurubooru/tests/func/test_net.py b/server/szurubooru/tests/func/test_net.py index f749d384..fb149b05 100644 --- a/server/szurubooru/tests/func/test_net.py +++ b/server/szurubooru/tests/func/test_net.py @@ -1,7 +1,10 @@ from szurubooru.func import net -def test_download(): +def test_download(config_injector): + config_injector({ + 'user_agent': None + }) url = 'http://info.cern.ch/hypertext/WWW/TheProject.html' expected_content = ( From 49feb932f35fbc5bdb6b09eb3f881a7c9df1e3ab Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 4 Mar 2017 16:55:53 +0100 Subject: [PATCH 064/159] client/tags: merging can now also add aliases --- client/html/tag_merge.tpl | 6 ++++-- client/js/controllers/tag_controller.js | 23 +++++++++++++---------- client/js/models/tag.js | 10 +++++++++- client/js/views/tag_merge_view.js | 7 ++++++- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/client/html/tag_merge.tpl b/client/html/tag_merge.tpl index 90d8aa63..0ffbdd2b 100644 --- a/client/html/tag_merge.tpl +++ b/client/html/tag_merge.tpl @@ -2,12 +2,14 @@
    diff --git a/client/html/posts_page.tpl b/client/html/posts_page.tpl index 895de559..6f0665c9 100644 --- a/client/html/posts_page.tpl +++ b/client/html/posts_page.tpl @@ -4,7 +4,7 @@ <% for (let post of ctx.response.results) { %>
  • ' + title='@<%- post.id %> (<%- post.type %>) Tags: <%- post.tags.map(tag => '#' + tag.names[0]).join(' ') || 'none' %>' href='<%= ctx.canViewPosts ? ctx.getPostUrl(post.id, ctx.parameters) : '' %>'> <%= ctx.makeThumbnail(post.thumbnailUrl) %> diff --git a/client/html/tag_edit.tpl b/client/html/tag_edit.tpl index a58c529e..49842852 100644 --- a/client/html/tag_edit.tpl +++ b/client/html/tag_edit.tpl @@ -22,18 +22,12 @@
  • <% if (ctx.canEditImplications) { %> - <%= ctx.makeTextInput({ - text: 'Implications', - value: ctx.tag.implications.join(' '), - }) %> + <%= ctx.makeTextInput({text: 'Implications'}) %> <% } %>
  • <% if (ctx.canEditSuggestions) { %> - <%= ctx.makeTextInput({ - text: 'Suggestions', - value: ctx.tag.suggestions.join(' '), - }) %> + <%= ctx.makeTextInput({text: 'Suggestions'}) %> <% } %>
  • diff --git a/client/html/tag_summary.tpl b/client/html/tag_summary.tpl index 0513d643..ad000841 100644 --- a/client/html/tag_summary.tpl +++ b/client/html/tag_summary.tpl @@ -9,7 +9,7 @@ Aliases:
      <% for (let name of ctx.tag.names.slice(1)) { %>
    • <%= ctx.makeTagLink(name) %>
    • <%= ctx.makeTagLink(name, false, false, ctx.tag) %>
    • <% } %>
    @@ -18,7 +18,7 @@ Implications:
      <% for (let tag of ctx.tag.implications) { %>
    • <%= ctx.makeTagLink(tag) %>
    • <%= ctx.makeTagLink(tag.names[0], false, false, tag) %>
    • <% } %>
    @@ -27,7 +27,7 @@ Suggestions:
      <% for (let tag of ctx.tag.suggestions) { %>
    • <%= ctx.makeTagLink(tag) %>
    • <%= ctx.makeTagLink(tag.names[0], false, false, tag) %>
    • <% } %>
    diff --git a/client/html/tags_page.tpl b/client/html/tags_page.tpl index 788d0e36..be3b143a 100644 --- a/client/html/tags_page.tpl +++ b/client/html/tags_page.tpl @@ -44,15 +44,15 @@
      <% for (let name of tag.names) { %> -
    • <%= ctx.makeTagLink(name) %>
    • +
    • <%= ctx.makeTagLink(name, false, false, tag) %>
    • <% } %>
    <% if (tag.implications.length) { %>
      - <% for (let name of tag.implications) { %> -
    • <%= ctx.makeTagLink(name) %>
    • + <% for (let relation of tag.implications) { %> +
    • <%= ctx.makeTagLink(relation.names[0], false, false, relation) %>
    • <% } %>
    <% } else { %> @@ -62,8 +62,8 @@ <% if (tag.suggestions.length) { %>
      - <% for (let name of tag.suggestions) { %> -
    • <%= ctx.makeTagLink(name) %>
    • + <% for (let relation of tag.suggestions) { %> +
    • <%= ctx.makeTagLink(relation.names[0], false, false, relation) %>
    • <% } %>
    <% } else { %> diff --git a/client/js/controllers/post_list_controller.js b/client/js/controllers/post_list_controller.js index 4fce8ea1..039ca187 100644 --- a/client/js/controllers/post_list_controller.js +++ b/client/js/controllers/post_list_controller.js @@ -62,15 +62,16 @@ class PostListController { } _evtTag(e) { - for (let tag of this._bulkEditTags) { - e.detail.post.addTag(tag); - } - e.detail.post.save().catch(error => window.alert(error.message)); + Promise.all( + this._bulkEditTags.map(tag => + e.detail.post.tags.addByName(tag))) + .then(() => { e.detail.post.save(); }) + .catch(error => window.alert(error.message)); } _evtUntag(e) { for (let tag of this._bulkEditTags) { - e.detail.post.removeTag(tag); + e.detail.post.tags.removeByName(tag); } e.detail.post.save().catch(error => window.alert(error.message)); } diff --git a/client/js/controllers/post_main_controller.js b/client/js/controllers/post_main_controller.js index bb1fd594..a482b844 100644 --- a/client/js/controllers/post_main_controller.js +++ b/client/js/controllers/post_main_controller.js @@ -132,9 +132,6 @@ class PostMainController extends BasePostController { this._view.sidebarControl.disableForm(); this._view.sidebarControl.clearMessages(); const post = e.detail.post; - if (e.detail.tags !== undefined) { - post.tags = e.detail.tags; - } if (e.detail.safety !== undefined) { post.safety = e.detail.safety; } diff --git a/client/js/controllers/post_upload_controller.js b/client/js/controllers/post_upload_controller.js index 3c326419..9dfdb4af 100644 --- a/client/js/controllers/post_upload_controller.js +++ b/client/js/controllers/post_upload_controller.js @@ -8,6 +8,7 @@ const misc = require('../util/misc.js'); const progress = require('../util/progress.js'); const topNavigation = require('../models/top_navigation.js'); const Post = require('../models/post.js'); +const Tag = require('../models/tag.js'); const PostUploadView = require('../views/post_upload_view.js'); const EmptyView = require('../views/empty_view.js'); @@ -144,7 +145,11 @@ class PostUploadController { let post = new Post(); post.safety = uploadable.safety; post.flags = uploadable.flags; - post.tags = uploadable.tags; + for (let tagName of uploadable.tags) { + const tag = new Tag(); + tag.names = [tagName]; + post.tags.add(tag); + } post.relations = uploadable.relations; post.newContent = uploadable.url || uploadable.file; return post; diff --git a/client/js/controllers/tag_categories_controller.js b/client/js/controllers/tag_categories_controller.js index dadf8e4a..49600cf3 100644 --- a/client/js/controllers/tag_categories_controller.js +++ b/client/js/controllers/tag_categories_controller.js @@ -40,7 +40,7 @@ class TagCategoriesController { this._view.disableForm(); this._tagCategories.save() .then(() => { - tags.refreshExport(); + tags.refreshCategoryColorMap(); this._view.enableForm(); this._view.showSuccess('Changes saved.'); }, error => { diff --git a/client/js/controllers/tag_controller.js b/client/js/controllers/tag_controller.js index d33e3d72..e5908405 100644 --- a/client/js/controllers/tag_controller.js +++ b/client/js/controllers/tag_controller.js @@ -4,8 +4,8 @@ const router = require('../router.js'); const api = require('../api.js'); const misc = require('../util/misc.js'); const uri = require('../util/uri.js'); -const tags = require('../tags.js'); const Tag = require('../models/tag.js'); +const TagCategoryList = require('../models/tag_category_list.js'); const topNavigation = require('../models/top_navigation.js'); const TagView = require('../views/tag_view.js'); const EmptyView = require('../views/empty_view.js'); @@ -18,7 +18,12 @@ class TagController { return; } - Tag.get(ctx.parameters.name).then(tag => { + Promise.all([ + TagCategoryList.get(), + Tag.get(ctx.parameters.name), + ]).then(responses => { + const [tagCategoriesResponse, tag] = responses; + topNavigation.activate('tags'); topNavigation.setTitle('Tag #' + tag.names[0]); @@ -26,7 +31,7 @@ class TagController { tag.addEventListener('change', e => this._evtSaved(e, section)); const categories = {}; - for (let category of tags.getAllCategories()) { + for (let category of tagCategoriesResponse.results) { categories[category.name] = category.name; } @@ -76,12 +81,6 @@ class TagController { if (e.detail.category !== undefined) { e.detail.tag.category = e.detail.category; } - if (e.detail.implications !== undefined) { - e.detail.tag.implications = e.detail.implications; - } - if (e.detail.suggestions !== undefined) { - e.detail.tag.suggestions = e.detail.suggestions; - } if (e.detail.description !== undefined) { e.detail.tag.description = e.detail.description; } diff --git a/client/js/controllers/tag_list_controller.js b/client/js/controllers/tag_list_controller.js index 3b4bd3e7..8bc7dbba 100644 --- a/client/js/controllers/tag_list_controller.js +++ b/client/js/controllers/tag_list_controller.js @@ -11,7 +11,12 @@ const TagsPageView = require('../views/tags_page_view.js'); const EmptyView = require('../views/empty_view.js'); const fields = [ - 'names', 'suggestions', 'implications', 'creationTime', 'usages']; + 'names', + 'suggestions', + 'implications', + 'creationTime', + 'usages', + 'category']; class TagListController { constructor(ctx) { diff --git a/client/js/controls/auto_complete_control.js b/client/js/controls/auto_complete_control.js index 3f46ab5b..5ce3bb2e 100644 --- a/client/js/controls/auto_complete_control.js +++ b/client/js/controls/auto_complete_control.js @@ -28,10 +28,7 @@ class AutoCompleteControl { this._sourceInputNode = sourceInputNode; this._options = {}; Object.assign(this._options, { - transform: null, verticalShift: 2, - source: null, - addSpace: false, maxResults: 15, getTextToFind: () => { const value = sourceInputNode.value; @@ -56,7 +53,7 @@ class AutoCompleteControl { this._isVisible = false; } - defaultConfirmStrategy(text) { + replaceSelectedText(result, addSpace) { const start = _getSelectionStart(this._sourceInputNode); let prefix = ''; let suffix = this._sourceInputNode.value.substring(start); @@ -66,30 +63,25 @@ class AutoCompleteControl { prefix = this._sourceInputNode.value.substring(0, index + 1); middle = this._sourceInputNode.value.substring(index + 1); } - this._sourceInputNode.value = prefix + text + ' ' + suffix.trimLeft(); - if (!this._options.addSpace) { + this._sourceInputNode.value = ( + prefix + result.toString() + ' ' + suffix.trimLeft()); + if (!addSpace) { this._sourceInputNode.value = this._sourceInputNode.value.trim(); } this._sourceInputNode.focus(); } - _delete(text) { - if (this._options.transform) { - text = this._options.transform(text); - } + _delete(result) { if (this._options.delete) { - this._options.delete(text); + this._options.delete(result); } } - _confirm(text) { - if (this._options.transform) { - text = this._options.transform(text); - } + _confirm(result) { if (this._options.confirm) { - this._options.confirm(text); + this._options.confirm(result); } else { - this.defaultConfirmStrategy(text); + this.defaultConfirmStrategy(result); } } @@ -104,7 +96,6 @@ class AutoCompleteControl { this.hide(); } else { this._updateResults(textToFind); - this._refreshList(); } } @@ -209,15 +200,16 @@ class AutoCompleteControl { } _updateResults(textToFind) { - const oldResults = this._results.slice(); - this._results = - this._options.getMatches(textToFind) - .slice(0, this._options.maxResults); - const oldResultsHash = JSON.stringify(oldResults); - const newResultsHash = JSON.stringify(this._results); - if (oldResultsHash !== newResultsHash) { - this._activeResult = -1; - } + this._options.getMatches(textToFind).then(matches => { + const oldResults = this._results.slice(); + this._results = matches.slice(0, this._options.maxResults); + const oldResultsHash = JSON.stringify(oldResults); + const newResultsHash = JSON.stringify(this._results); + if (oldResultsHash !== newResultsHash) { + this._activeResult = -1; + } + this._refreshList(); + }); } _refreshList() { diff --git a/client/js/controls/post_edit_sidebar_control.js b/client/js/controls/post_edit_sidebar_control.js index 72cabc9c..38dded28 100644 --- a/client/js/controls/post_edit_sidebar_control.js +++ b/client/js/controls/post_edit_sidebar_control.js @@ -72,7 +72,8 @@ class PostEditSidebarControl extends events.EventTarget { } if (this._tagInputNode) { - this._tagControl = new TagInputControl(this._tagInputNode); + this._tagControl = new TagInputControl( + this._tagInputNode, post.tags); } if (this._contentInputNode) { @@ -171,10 +172,11 @@ class PostEditSidebarControl extends events.EventTarget { }); } - this._tagControl.addEventListener('change', e => { - this._post.tags = this._tagControl.tags; - this._syncExpanderTitles(); - }); + this._tagControl.addEventListener( + 'change', e => { + this.dispatchEvent(new CustomEvent('change')); + this._syncExpanderTitles(); + }); if (this._noteTextareaNode) { this._noteTextareaNode.addEventListener( diff --git a/client/js/controls/post_readonly_sidebar_control.js b/client/js/controls/post_readonly_sidebar_control.js index 30edd410..bb51e69d 100644 --- a/client/js/controls/post_readonly_sidebar_control.js +++ b/client/js/controls/post_readonly_sidebar_control.js @@ -3,7 +3,6 @@ const api = require('../api.js'); const config = require('../config.js'); const events = require('../events.js'); -const tags = require('../tags.js'); const views = require('../util/views.js'); const template = views.getTemplate('post-readonly-sidebar'); @@ -22,8 +21,6 @@ class PostReadonlySidebarControl extends events.EventTarget { views.replaceContent(this._hostNode, template({ post: this._post, - getTagCategory: this._getTagCategory, - getTagUsages: this._getTagUsages, enableSafety: config.enableSafety, canListPosts: api.hasPrivilege('posts:list'), canEditPosts: api.hasPrivilege('posts:edit'), @@ -161,16 +158,6 @@ class PostReadonlySidebarControl extends events.EventTarget { newNode.classList.add('active'); } - _getTagUsages(name) { - const tag = tags.getTagByName(name); - return tag ? tag.usages : 0; - } - - _getTagCategory(name) { - const tag = tags.getTagByName(name); - return tag ? tag.category : 'unknown'; - } - _evtAddToFavoritesClick(e) { e.preventDefault(); this.dispatchEvent(new CustomEvent('favorite', { diff --git a/client/js/controls/tag_auto_complete_control.js b/client/js/controls/tag_auto_complete_control.js index 47434c7d..b7dda201 100644 --- a/client/js/controls/tag_auto_complete_control.js +++ b/client/js/controls/tag_auto_complete_control.js @@ -1,9 +1,33 @@ 'use strict'; -const tags = require('../tags.js'); const misc = require('../util/misc.js'); +const views = require('../util/views.js'); +const TagList = require('../models/tag_list.js'); const AutoCompleteControl = require('./auto_complete_control.js'); +function _escapeSearch(text) { + return text.replace('\\', '\\\\').replace(':', '\\:'); +} + +function _tagListToMatches(tags, options) { + return [...tags].sort((tag1, tag2) => { + return tag2.usages - tag1.usages; + }).map(tag => { + let cssName = misc.makeCssName(tag.category, 'tag'); + if (options.isTaggedWith(tag.names[0])) { + cssName += ' disabled'; + } + const caption = ( + '' + + misc.escapeHtml(tag.names[0] + ' (' + tag.postCount + ')') + + ''); + return { + caption: caption, + value: tag, + }; + }); +} + class TagAutoCompleteControl extends AutoCompleteControl { constructor(input, options) { const minLengthForPartialSearch = 3; @@ -13,32 +37,21 @@ class TagAutoCompleteControl extends AutoCompleteControl { }, options); options.getMatches = text => { - const transform = x => x.toLowerCase(); - const match = text.length < minLengthForPartialSearch ? - (a, b) => a.startsWith(b) : - (a, b) => a.includes(b); - text = transform(text); - return Array.from(tags.getNameToTagMap().entries()) - .filter(kv => match(transform(kv[0]), text)) - .sort((kv1, kv2) => { - return kv2[1].usages - kv1[1].usages; - }) - .map(kv => { - const origName = tags.getOriginalTagName(kv[0]); - const category = kv[1].category; - const usages = kv[1].usages; - let cssName = misc.makeCssName(category, 'tag'); - if (options.isTaggedWith(kv[0])) { - cssName += ' disabled'; - } - return { - caption: misc.unindent` - - ${misc.escapeHtml(origName)} (${usages}) - `, - value: origName, - }; - }); + const term = misc.escapeSearchTerm(text); + const query = ( + text.length < minLengthForPartialSearch + ? term + '*' + : '*' + term + '*') + ' sort:usages'; + + return new Promise((resolve, reject) => { + TagList.search( + query, 0, this._options.maxResults, + ['names', 'category', 'usages']) + .then( + response => resolve( + _tagListToMatches(response.results, this._options)), + reject); + }); }; super(input, options); diff --git a/client/js/controls/tag_input_control.js b/client/js/controls/tag_input_control.js index ee864281..13994374 100644 --- a/client/js/controls/tag_input_control.js +++ b/client/js/controls/tag_input_control.js @@ -4,6 +4,7 @@ const api = require('../api.js'); const tags = require('../tags.js'); const misc = require('../util/misc.js'); const uri = require('../util/uri.js'); +const Tag = require('../models/tag.js'); const settings = require('../models/settings.js'); const events = require('../events.js'); const views = require('../util/views.js'); @@ -80,11 +81,12 @@ class SuggestionList { } class TagInputControl extends events.EventTarget { - constructor(hostNode) { + constructor(hostNode, tagList) { super(); - this.tags = []; + this.tags = tagList; this._hostNode = hostNode; this._suggestions = new SuggestionList(); + this._tagToListItemNode = new Map(); // dom const editAreaNode = template(); @@ -98,16 +100,18 @@ class TagInputControl extends events.EventTarget { getTextToFind: () => { return this._tagInputNode.value; }, - confirm: text => { + confirm: tag => { this._tagInputNode.value = ''; - this.addTag(text, SOURCE_USER_INPUT); + // XXX: tags from autocomplete don't contain implications + // so they need to be looked up in API + this.addTagByName(tag.names[0], SOURCE_USER_INPUT); }, - delete: text => { + delete: tag => { this._tagInputNode.value = ''; - this.deleteTag(text); + this.deleteTag(tag); }, verticalShift: -2, - isTaggedWith: tagName => this.isTaggedWith(tagName), + isTaggedWith: tagName => this.tags.isTaggedWith(tagName), }); // dom events @@ -127,114 +131,81 @@ class TagInputControl extends events.EventTarget { this._hostNode.parentNode.insertBefore( this._editAreaNode, hostNode.nextSibling); - this.addEventListener('change', e => this._evtTagsChanged(e)); - this.addEventListener('add', e => this._evtTagAdded(e)); - this.addEventListener('remove', e => this._evtTagRemoved(e)); - // add existing tags - this.addMultipleTags(this._hostNode.value, SOURCE_INIT); + for (let tag of [...this.tags]) { + const listItemNode = this._createListItemNode(tag); + this._tagListNode.appendChild(listItemNode); + } } - isTaggedWith(tagName) { - let actualTag = null; - [tagName, actualTag] = this._transformTagName(tagName); - return this.tags - .map(t => t.toLowerCase()) - .includes(tagName.toLowerCase()); - } - - addMultipleTags(text, source) { + addTagByText(text, source) { for (let tagName of text.split(/\s+/).filter(word => word).reverse()) { - this.addTag(tagName, source); + this.addTagByName(tagName, source); } } - addTag(tagName, source) { - tagName = tags.getOriginalTagName(tagName); - - if (!tagName) { + addTagByName(name, source) { + name = name.trim(); + if (!name) { return; } - - let actualTag = null; - [tagName, actualTag] = this._transformTagName(tagName); - if (!this.isTaggedWith(tagName)) { - this.tags.push(tagName); - } - this.dispatchEvent(new CustomEvent('add', { - detail: { - tagName: tagName, - source: source, - }, - })); - this.dispatchEvent(new CustomEvent('change')); - - // XXX: perhaps we should aggregate suggestions from all implications - // for call to the _suggestRelations - if (source !== SOURCE_INIT && source !== SOURCE_CLIPBOARD) { - for (let otherTagName of tags.getAllImplications(tagName)) { - this.addTag(otherTagName, SOURCE_IMPLICATION); - } - } + return Tag.get(name).then(tag => { + return this.addTag(tag, source); + }, () => { + const tag = new Tag(); + tag.names = [name]; + tag.category = null; + return this.addTag(tag, source); + }); } - deleteTag(tagName) { - if (!tagName) { - return; - } - let actualTag = null; - [tagName, actualTag] = this._transformTagName(tagName); - if (!this.isTaggedWith(tagName)) { - return; - } - this._hideAutoComplete(); - this.tags = this.tags.filter( - t => t.toLowerCase() != tagName.toLowerCase()); - this.dispatchEvent(new CustomEvent('remove', { - detail: { - tagName: tagName, - }, - })); - this.dispatchEvent(new CustomEvent('change')); - } - - _evtTagsChanged(e) { - this._hostNode.value = this.tags.join(' '); - this._hostNode.dispatchEvent(new CustomEvent('change')); - } - - _evtTagAdded(e) { - const tagName = e.detail.tagName; - const actualTag = tags.getTagByName(tagName); - let listItemNode = this._getListItemNodeFromTagName(tagName); - const alreadyAdded = !!listItemNode; - if (alreadyAdded) { - if (e.detail.source !== SOURCE_IMPLICATION) { + addTag(tag, source) { + if (source != SOURCE_INIT && this.tags.isTaggedWith(tag.names[0])) { + const listItemNode = this._getListItemNode(tag); + if (source !== SOURCE_IMPLICATION) { listItemNode.classList.add('duplicate'); + _fadeOutListItemNodeStatus(listItemNode); } - } else { - listItemNode = this._createListItemNode(tagName); - if (!actualTag) { + return Promise.resolve(); + } + + return this.tags.addByName(tag.names[0], false).then(() => { + const listItemNode = this._createListItemNode(tag); + if (!tag.category) { listItemNode.classList.add('new'); } - if (e.detail.source === SOURCE_IMPLICATION) { + if (source === SOURCE_IMPLICATION) { listItemNode.classList.add('implication'); } this._tagListNode.prependChild(listItemNode); - } - _fadeOutListItemNodeStatus(listItemNode); + _fadeOutListItemNodeStatus(listItemNode); - if ([SOURCE_USER_INPUT, SOURCE_SUGGESTION].includes(e.detail.source) && - actualTag) { - this._loadSuggestions(actualTag); - } + return Promise.all( + tag.implications.map( + implication => this.addTagByName( + implication.names[0], SOURCE_IMPLICATION))); + }).then(() => { + this.dispatchEvent(new CustomEvent('add', { + detail: {tag: tag, source: source}, + })); + this.dispatchEvent(new CustomEvent('change')); + return Promise.resolve(); + }); } - _evtTagRemoved(e) { - const listItemNode = this._getListItemNodeFromTagName(e.detail.tagName); - if (listItemNode) { - listItemNode.parentNode.removeChild(listItemNode); + deleteTag(tag) { + if (!this.tags.isTaggedWith(tag.names[0])) { + return; } + this.tags.removeByName(tag.names[0]); + this._hideAutoComplete(); + + this._deleteListItemNode(tag); + + this.dispatchEvent(new CustomEvent('remove', { + detail: {tag: tag}, + })); + this.dispatchEvent(new CustomEvent('change')); } _evtInputPaste(e) { @@ -248,7 +219,7 @@ class TagInputControl extends events.EventTarget { return; } this._hideAutoComplete(); - this.addMultipleTags(pastedText, SOURCE_CLIPBOARD); + this.addTagByText(pastedText, SOURCE_CLIPBOARD); this._tagInputNode.value = ''; } @@ -259,7 +230,7 @@ class TagInputControl extends events.EventTarget { _evtAddTagButtonClick(e) { e.preventDefault(); - this.addTag(this._tagInputNode.value, SOURCE_USER_INPUT); + this.addTagByName(this._tagInputNode.value, SOURCE_USER_INPUT); this._tagInputNode.value = ''; } @@ -272,36 +243,14 @@ class TagInputControl extends events.EventTarget { if (e.which == KEY_RETURN || e.which == KEY_SPACE) { e.preventDefault(); this._hideAutoComplete(); - this.addMultipleTags(this._tagInputNode.value, SOURCE_USER_INPUT); + this.addTagByText(this._tagInputNode.value, SOURCE_USER_INPUT); this._tagInputNode.value = ''; } } - _transformTagName(tagName) { - const actualTag = tags.getTagByName(tagName); - if (actualTag) { - tagName = actualTag.names[0]; - } - return [tagName, actualTag]; - } - - _getListItemNodeFromTagName(tagName) { - let actualTag = null; - [tagName, actualTag] = this._transformTagName(tagName); - for (let listItemNode of this._tagListNode.querySelectorAll('li')) { - if (listItemNode.getAttribute('data-tag').toLowerCase() === - tagName.toLowerCase()) { - return listItemNode; - } - } - return null; - } - - _createListItemNode(tagName) { - let actualTag = null; - [tagName, actualTag] = this._transformTagName(tagName); - const className = actualTag ? - misc.makeCssName(actualTag.category, 'tag') : + _createListItemNode(tag) { + const className = tag.category ? + misc.makeCssName(tag.category, 'tag') : null; const tagLinkNode = document.createElement('a'); @@ -309,7 +258,8 @@ class TagInputControl extends events.EventTarget { tagLinkNode.classList.add(className); } tagLinkNode.setAttribute( - 'href', uri.formatClientLink('tag', tagName)); + 'href', uri.formatClientLink('tag', tag.names[0])); + const tagIconNode = document.createElement('i'); tagIconNode.classList.add('fa'); tagIconNode.classList.add('fa-tag'); @@ -320,13 +270,13 @@ class TagInputControl extends events.EventTarget { searchLinkNode.classList.add(className); } searchLinkNode.setAttribute( - 'href', uri.formatClientLink('posts', {query: tagName})); - searchLinkNode.textContent = tagName + ' '; + 'href', uri.formatClientLink('posts', {query: tag.names[0]})); + searchLinkNode.textContent = tag.names[0] + ' '; searchLinkNode.addEventListener('click', e => { e.preventDefault(); - if (actualTag) { - this._suggestions.clear(); - this._loadSuggestions(actualTag); + this._suggestions.clear(); + if (tag.postCount > 0) { + this._loadSuggestions(tag); this._removeSuggestionsPopupOpacity(); } else { this._closeSuggestionsPopup(); @@ -335,8 +285,7 @@ class TagInputControl extends events.EventTarget { const usagesNode = document.createElement('span'); usagesNode.classList.add('tag-usages'); - usagesNode.setAttribute( - 'data-pseudo-content', actualTag ? actualTag.usages : 0); + usagesNode.setAttribute('data-pseudo-content', tag.postCount); const removalLinkNode = document.createElement('a'); removalLinkNode.classList.add('remove-tag'); @@ -344,18 +293,34 @@ class TagInputControl extends events.EventTarget { removalLinkNode.setAttribute('data-pseudo-content', '×'); removalLinkNode.addEventListener('click', e => { e.preventDefault(); - this.deleteTag(tagName); + this.deleteTag(tag); }); const listItemNode = document.createElement('li'); - listItemNode.setAttribute('data-tag', tagName); listItemNode.appendChild(removalLinkNode); listItemNode.appendChild(tagLinkNode); listItemNode.appendChild(searchLinkNode); listItemNode.appendChild(usagesNode); + for (let name of tag.names) { + this._tagToListItemNode.set(name, listItemNode); + } return listItemNode; } + _deleteListItemNode(tag) { + const listItemNode = this._getListItemNode(tag); + if (listItemNode) { + listItemNode.parentNode.removeChild(listItemNode); + } + for (let name of tag.names) { + this._tagToListItemNode.delete(name); + } + } + + _getListItemNode(tag) { + return this._tagToListItemNode.get(tag.names[0]); + } + _loadSuggestions(tag) { const browsingSettings = settings.get(); if (!browsingSettings.tagSuggestions) { @@ -399,23 +364,22 @@ class TagInputControl extends events.EventTarget { for (let tuple of this._suggestions.getAll()) { const tagName = tuple.tagName; const weight = tuple.weight; - if (this.isTaggedWith(tagName)) { + if (this.tags.isTaggedWith(tagName)) { continue; } - const actualTag = tags.getTagByName(tagName); const addLinkNode = document.createElement('a'); addLinkNode.textContent = tagName; addLinkNode.classList.add('add-tag'); addLinkNode.setAttribute('href', ''); - if (actualTag) { + Tag.get(tagName).then(tag => { addLinkNode.classList.add( - misc.makeCssName(actualTag.category, 'tag')); - } + misc.makeCssName(tag.category, 'tag')); + }); addLinkNode.addEventListener('click', e => { e.preventDefault(); listNode.removeChild(listItemNode); - this.addTag(tagName, SOURCE_SUGGESTION); + this.addTagByName(tagName, SOURCE_SUGGESTION); }); const weightNode = document.createElement('span'); diff --git a/client/js/main.js b/client/js/main.js index 2308ee90..71284f8e 100644 --- a/client/js/main.js +++ b/client/js/main.js @@ -55,7 +55,7 @@ for (let controller of controllers) { const tags = require('./tags.js'); const api = require('./api.js'); -tags.refreshExport(); // we don't care about errors +tags.refreshCategoryColorMap(); // we don't care about errors api.loginFromCookies().then(() => { router.start(); }, error => { diff --git a/client/js/models/abstract_list.js b/client/js/models/abstract_list.js index 6544ed29..10edd612 100644 --- a/client/js/models/abstract_list.js +++ b/client/js/models/abstract_list.js @@ -27,6 +27,13 @@ class AbstractList extends events.EventTarget { return ret; } + sync(plainList) { + this.clear(); + for (let item of (plainList || [])) { + this.add(this.constructor._itemClass.fromResponse(item)); + } + } + add(item) { if (item.addEventListener) { item.addEventListener('delete', e => { @@ -75,6 +82,10 @@ class AbstractList extends events.EventTarget { return this._list[index]; } + map(...args) { + return this._list.map(...args); + } + [Symbol.iterator]() { return this._list[Symbol.iterator](); } diff --git a/client/js/models/post.js b/client/js/models/post.js index 71f3eb98..2b63bd62 100644 --- a/client/js/models/post.js +++ b/client/js/models/post.js @@ -4,23 +4,18 @@ const api = require('../api.js'); const uri = require('../util/uri.js'); const tags = require('../tags.js'); const events = require('../events.js'); +const TagList = require('./tag_list.js'); const NoteList = require('./note_list.js'); const CommentList = require('./comment_list.js'); const misc = require('../util/misc.js'); -function _syncObservableCollection(target, plainList) { - target.clear(); - for (let item of (plainList || [])) { - target.add(target.constructor._itemClass.fromResponse(item)); - } -} - class Post extends events.EventTarget { constructor() { super(); this._orig = {}; for (let obj of [this, this._orig]) { + obj._tags = new TagList(); obj._notes = new NoteList(); obj._comments = new CommentList(); } @@ -56,7 +51,6 @@ class Post extends events.EventTarget { get hasCustomThumbnail() { return this._hasCustomThumbnail; } set flags(value) { this._flags = value; } - set tags(value) { this._tags = value; } set safety(value) { this._safety = value; } set relations(value) { this._relations = value; } set newContent(value) { this._newContent = value; } @@ -94,29 +88,6 @@ class Post extends events.EventTarget { }); } - isTaggedWith(tagName) { - return this._tags - .map(s => s.toLowerCase()) - .includes(tagName.toLowerCase()); - } - - addTag(tagName, addImplications) { - if (this.isTaggedWith(tagName)) { - return; - } - this._tags.push(tagName); - if (addImplications !== false) { - for (let otherTag of tags.getAllImplications(tagName)) { - this.addTag(otherTag, addImplications); - } - } - } - - removeTag(tagName) { - this._tags = this._tags.filter( - s => s.toLowerCase() != tagName.toLowerCase()); - } - save(anonymous) { const files = {}; const detail = {version: this._version}; @@ -132,15 +103,14 @@ class Post extends events.EventTarget { detail.flags = this._flags; } if (misc.arraysDiffer(this._tags, this._orig._tags)) { - detail.tags = this._tags; + detail.tags = this._tags.map(tag => tag.names[0]); } if (misc.arraysDiffer(this._relations, this._orig._relations)) { detail.relations = this._relations; } if (misc.arraysDiffer(this._notes, this._orig._notes)) { - detail.notes = [...this._notes].map(note => ({ - polygon: [...note.polygon].map( - point => [point.x, point.y]), + detail.notes = this._notes.map(note => ({ + polygon: note.polygon.map(point => [point.x, point.y]), text: note.text, })); } @@ -310,7 +280,6 @@ class Post extends events.EventTarget { _fileSize: response.fileSize, _flags: [...response.flags || []], - _tags: [...response.tags || []], _relations: [...response.relations || []], _score: response.score, @@ -322,8 +291,9 @@ class Post extends events.EventTarget { }); for (let obj of [this, this._orig]) { - _syncObservableCollection(obj._notes, response.notes); - _syncObservableCollection(obj._comments, response.comments); + obj._tags.sync(response.tags); + obj._notes.sync(response.notes); + obj._comments.sync(response.comments); } Object.assign(this, map()); diff --git a/client/js/models/tag.js b/client/js/models/tag.js index ccfe6b21..c7435540 100644 --- a/client/js/models/tag.js +++ b/client/js/models/tag.js @@ -7,8 +7,16 @@ const misc = require('../util/misc.js'); class Tag extends events.EventTarget { constructor() { + const TagList = require('./tag_list.js'); + super(); this._orig = {}; + + for (let obj of [this, this._orig]) { + obj._suggestions = new TagList(); + obj._implications = new TagList(); + } + this._updateFromResponse({}); } @@ -24,8 +32,6 @@ class Tag extends events.EventTarget { set names(value) { this._names = value; } set category(value) { this._category = value; } set description(value) { this._description = value; } - set implications(value) { this._implications = value; } - set suggestions(value) { this._suggestions = value; } static fromResponse(response) { const ret = new Tag(); @@ -54,10 +60,12 @@ class Tag extends events.EventTarget { detail.description = this._description; } if (misc.arraysDiffer(this._implications, this._orig._implications)) { - detail.implications = this._implications; + detail.implications = this._implications.map( + relation => relation.names[0]); } if (misc.arraysDiffer(this._suggestions, this._orig._suggestions)) { - detail.suggestions = this._suggestions; + detail.suggestions = this._suggestions.map( + relation => relation.names[0]); } let promise = this._origName ? @@ -124,13 +132,16 @@ class Tag extends events.EventTarget { _names: response.names, _category: response.category, _description: response.description, - _implications: response.implications, - _suggestions: response.suggestions, _creationTime: response.creationTime, _lastEditTime: response.lastEditTime, - _postCount: response.usages, + _postCount: response.usages || 0, }; + for (let obj of [this, this._orig]) { + obj._suggestions.sync(response.suggestions); + obj._implications.sync(response.implications); + } + Object.assign(this, map); Object.assign(this._orig, map); } diff --git a/client/js/models/tag_list.js b/client/js/models/tag_list.js index 282d4ef7..d87b694e 100644 --- a/client/js/models/tag_list.js +++ b/client/js/models/tag_list.js @@ -22,6 +22,48 @@ class TagList extends AbstractList { {results: TagList.fromResponse(response.results)})); }); } + + isTaggedWith(testName) { + for (let tag of this._list) { + for (let tagName of tag.names) { + if (tagName.toLowerCase() === testName.toLowerCase()) { + return true; + } + } + } + return false; + } + + addByName(tagName, addImplications) { + if (this.isTaggedWith(tagName)) { + return Promise.resolve(); + } + + const tag = new Tag(); + tag.names = [tagName]; + + this.add(tag); + + if (addImplications !== false) { + return Tag.get(tagName).then(actualTag => { + return Promise.all( + actualTag.implications.map( + relation => this.addByName(relation.names[0], true))); + }); + } + + return Promise.resolve(); + } + + removeByName(testName) { + for (let tag of this._list) { + for (let tagName of tag.names) { + if (tagName.toLowerCase() === testName.toLowerCase()) { + this.remove(tag); + } + } + } + } } TagList._itemClass = Tag; diff --git a/client/js/tags.js b/client/js/tags.js index c0c301b5..8f83a428 100644 --- a/client/js/tags.js +++ b/client/js/tags.js @@ -1,92 +1,23 @@ 'use strict'; const misc = require('./util/misc.js'); -const request = require('superagent'); +const TagCategoryList = require('./models/tag_category_list.js'); -let _tags = new Map(); -let _categories = new Map(); let _stylesheet = null; -function getTagByName(name) { - return _tags.get(name.toLowerCase()); -} - -function getCategoryByName(name) { - return _categories.get(name.toLowerCase()); -} - -function getNameToTagMap() { - return _tags; -} - -function getAllTags() { - return _tags.values(); -} - -function getAllCategories() { - return _categories.values(); -} - -function getOriginalTagName(name) { - const actualTag = getTagByName(name); - if (actualTag) { - for (let originalName of actualTag.names) { - if (originalName.toLowerCase() == name.toLowerCase()) { - return originalName; - } +function refreshCategoryColorMap() { + return TagCategoryList.get().then(response => { + if (_stylesheet) { + document.head.removeChild(_stylesheet); } - } - return name; -} - -function _tagsToMap(tags) { - let map = new Map(); - for (let tag of tags) { - for (let name of tag.names) { - map.set(name.toLowerCase(), tag); + _stylesheet = document.createElement('style'); + document.head.appendChild(_stylesheet); + for (let category of response.results) { + const ruleName = misc.makeCssName(category.name, 'tag'); + _stylesheet.sheet.insertRule( + `.${ruleName} { color: ${category.color} }`, + _stylesheet.sheet.cssRules.length); } - } - return map; -} - -function _tagCategoriesToMap(categories) { - let map = new Map(); - for (let category of categories) { - map.set(category.name.toLowerCase(), category); - } - return map; -} - -function _refreshStylesheet() { - if (_stylesheet) { - document.head.removeChild(_stylesheet); - } - _stylesheet = document.createElement('style'); - document.head.appendChild(_stylesheet); - for (let category of getAllCategories()) { - const ruleName = misc.makeCssName(category.name, 'tag'); - _stylesheet.sheet.insertRule( - `.${ruleName} { color: ${category.color} }`, - _stylesheet.sheet.cssRules.length); - } -} - -function refreshExport() { - return new Promise((resolve, reject) => { - request.get('/data/tags.json').end((error, response) => { - if (error) { - _tags = new Map(); - _categories = new Map(); - reject(error); - return; - } - _tags = _tagsToMap( - response.body ? response.body.tags : []); - _categories = _tagCategoriesToMap( - response.body ? response.body.categories : []); - _refreshStylesheet(); - resolve(); - }); }); } @@ -107,19 +38,7 @@ function getAllImplications(tagName) { return Array.from(implications); } -function getSuggestions(tagName) { - const actualTag = getTagByName(tagName) || {}; - return actualTag.suggestions || []; -} - module.exports = { - getAllCategories: getAllCategories, - getAllTags: getAllTags, - getTagByName: getTagByName, - getCategoryByName: getCategoryByName, - getNameToTagMap: getNameToTagMap, - getOriginalTagName: getOriginalTagName, - refreshExport: refreshExport, - getAllImplications: getAllImplications, - getSuggestions: getSuggestions, + refreshCategoryColorMap: refreshCategoryColorMap, + getAllImplications: getAllImplications, }; diff --git a/client/js/util/views.js b/client/js/util/views.js index db5ae4b0..b0b7ccec 100644 --- a/client/js/util/views.js +++ b/client/js/util/views.js @@ -3,7 +3,6 @@ require('../util/polyfill.js'); const api = require('../api.js'); const templates = require('../templates.js'); -const tags = require('../tags.js'); const domParser = new DOMParser(); const misc = require('./misc.js'); const uri = require('./uri.js'); @@ -194,13 +193,15 @@ function makePostLink(id, includeHash) { misc.escapeHtml(text); } -function makeTagLink(name, includeHash) { - const tag = tags.getTagByName(name); +function makeTagLink(name, includeHash, includeCount, tag) { const category = tag ? tag.category : 'unknown'; let text = name; if (includeHash === true) { text = '#' + text; } + if (includeCount === true) { + text += ' (' + (tag ? tag.postCount : 0) + ')'; + } return api.hasPrivilege('tags:view') ? makeElement( 'a', diff --git a/client/js/views/home_view.js b/client/js/views/home_view.js index b8df0153..7965e3d2 100644 --- a/client/js/views/home_view.js +++ b/client/js/views/home_view.js @@ -2,6 +2,7 @@ const router = require('../router.js'); const uri = require('../util/uri.js'); +const misc = require('../util/misc.js'); const views = require('../util/views.js'); const PostContentControl = require('../controls/post_content_control.js'); const PostNotesOverlayControl @@ -23,12 +24,16 @@ class HomeView { views.syncScrollPosition(); if (this._formNode) { - this._tagAutoCompleteControl = new TagAutoCompleteControl( - this._searchInputNode); + this._autoCompleteControl = new TagAutoCompleteControl( + this._searchInputNode, + { + confirm: tag => + this._autoCompleteControl.replaceSelectedText( + misc.escapeSearchTerm(tag.names[0]), true), + }); this._formNode.addEventListener( 'submit', e => this._evtFormSubmit(e)); } - } showSuccess(text) { diff --git a/client/js/views/posts_header_view.js b/client/js/views/posts_header_view.js index 3f70ef5a..6b697c4f 100644 --- a/client/js/views/posts_header_view.js +++ b/client/js/views/posts_header_view.js @@ -73,7 +73,12 @@ class BulkTagEditor extends BulkEditor { constructor(hostNode) { super(hostNode); this._autoCompleteControl = new TagAutoCompleteControl( - this._inputNode, {addSpace: false}); + this._inputNode, + { + confirm: tag => + this._autoCompleteControl.replaceSelectedText( + tag.names[0], false), + }); this._hostNode.addEventListener('submit', e => this._evtFormSubmit(e)); } @@ -124,9 +129,13 @@ class PostsHeaderView extends events.EventTarget { this._hostNode = ctx.hostNode; views.replaceContent(this._hostNode, template(ctx)); - this._queryAutoCompleteControl = new TagAutoCompleteControl( + this._autoCompleteControl = new TagAutoCompleteControl( this._queryInputNode, - {addSpace: true, transform: misc.escapeSearchTerm}); + { + confirm: tag => + this._autoCompleteControl.replaceSelectedText( + misc.escapeSearchTerm(tag.names[0]), true), + }); keyboard.bind('p', () => this._focusFirstPostNode()); search.searchInputNodeFocusHelper(this._queryInputNode); @@ -235,7 +244,7 @@ class PostsHeaderView extends events.EventTarget { } _navigate() { - this._queryAutoCompleteControl.hide(); + this._autoCompleteControl.hide(); let parameters = {query: this._queryInputNode.value}; parameters.offset = parameters.query === this._ctx.parameters.query ? this._ctx.parameters.offset : 0; diff --git a/client/js/views/posts_page_view.js b/client/js/views/posts_page_view.js index 26103251..292c6175 100644 --- a/client/js/views/posts_page_view.js +++ b/client/js/views/posts_page_view.js @@ -100,7 +100,7 @@ class PostsPageView extends events.EventTarget { if (tagFlipperNode) { let tagged = true; for (let tag of this._ctx.bulkEdit.tags) { - tagged = tagged & post.isTaggedWith(tag); + tagged = tagged & post.tags.isTaggedWith(tag); } tagFlipperNode.classList.toggle('tagged', tagged); } diff --git a/client/js/views/tag_edit_view.js b/client/js/views/tag_edit_view.js index 600cebc9..77c7cefc 100644 --- a/client/js/views/tag_edit_view.js +++ b/client/js/views/tag_edit_view.js @@ -24,10 +24,12 @@ class TagEditView extends events.EventTarget { } if (this._implicationsFieldNode) { - new TagInputControl(this._implicationsFieldNode); + new TagInputControl( + this._implicationsFieldNode, this._tag.implications); } if (this._suggestionsFieldNode) { - new TagInputControl(this._suggestionsFieldNode); + new TagInputControl( + this._suggestionsFieldNode, this._tag.suggestions); } for (let node of this._formNode.querySelectorAll( diff --git a/client/js/views/tag_merge_view.js b/client/js/views/tag_merge_view.js index 87ffb478..9b49097d 100644 --- a/client/js/views/tag_merge_view.js +++ b/client/js/views/tag_merge_view.js @@ -2,6 +2,7 @@ const config = require('../config.js'); const events = require('../events.js'); +const misc = require('../util/misc.js'); const views = require('../util/views.js'); const TagAutoCompleteControl = require('../controls/tag_auto_complete_control.js'); @@ -19,7 +20,13 @@ class TagMergeView extends events.EventTarget { views.decorateValidator(this._formNode); if (this._targetTagFieldNode) { - new TagAutoCompleteControl(this._targetTagFieldNode); + this._autoCompleteControl = new TagAutoCompleteControl( + this._targetTagFieldNode, + { + confirm: tag => + this._autoCompleteControl.replaceSelectedText( + tag.names[0], false), + }); } this._formNode.addEventListener('submit', e => this._evtSubmit(e)); diff --git a/client/js/views/tags_header_view.js b/client/js/views/tags_header_view.js index 8b1b4767..05b32f32 100644 --- a/client/js/views/tags_header_view.js +++ b/client/js/views/tags_header_view.js @@ -17,8 +17,13 @@ class TagsHeaderView extends events.EventTarget { views.replaceContent(this._hostNode, template(ctx)); if (this._queryInputNode) { - new TagAutoCompleteControl( - this._queryInputNode, {transform: misc.escapeSearchTerm}); + this._autoCompleteControl = new TagAutoCompleteControl( + this._queryInputNode, + { + confirm: tag => + this._autoCompleteControl.replaceSelectedText( + misc.escapeSearchTerm(tag.names[0]), true), + }); } search.searchInputNodeFocusHelper(this._queryInputNode); diff --git a/config.yaml.dist b/config.yaml.dist index abffba0b..42267452 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -106,7 +106,7 @@ privileges: 'tags:edit:description': power 'tags:edit:implications': power 'tags:edit:suggestions': power - 'tags:list': regular # note: will be available as data_url/tags.json anyway + 'tags:list': regular 'tags:view': anonymous 'tags:merge': moderator 'tags:delete': moderator @@ -114,7 +114,7 @@ privileges: 'tag_categories:create': moderator 'tag_categories:edit:name': moderator 'tag_categories:edit:color': moderator - 'tag_categories:list': anonymous # note: will be available as data_url/tags.json anyway + 'tag_categories:list': anonymous 'tag_categories:view': anonymous 'tag_categories:delete': moderator 'tag_categories:set_default': moderator diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 51deed2f..27f10c16 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -73,7 +73,6 @@ def create_post( for tag in new_tags: snapshots.create(tag, None if anonymous else ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize_post(ctx, post) @@ -126,7 +125,6 @@ def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: ctx.session.flush() snapshots.modify(post, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize_post(ctx, post) @@ -138,7 +136,6 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: snapshots.delete(post, ctx.user) posts.delete(post) ctx.session.commit() - tags.export_to_json() return {} diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index f02f8030..f9a15e76 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -54,7 +54,6 @@ def create_tag( ctx.session.flush() snapshots.create(tag, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize(ctx, tag) @@ -95,7 +94,6 @@ def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: ctx.session.flush() snapshots.modify(tag, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize(ctx, tag) @@ -107,7 +105,6 @@ def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: snapshots.delete(tag, ctx.user) tags.delete(tag) ctx.session.commit() - tags.export_to_json() return {} @@ -125,7 +122,6 @@ def merge_tags( tags.merge_tags(source_tag, target_tag) snapshots.merge(source_tag, target_tag, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize(ctx, target_tag) diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index 63c05591..07da9993 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -31,7 +31,6 @@ def create_tag_category( ctx.session.flush() snapshots.create(category, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize(ctx, category) @@ -61,7 +60,6 @@ def update_tag_category( ctx.session.flush() snapshots.modify(category, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize(ctx, category) @@ -75,7 +73,6 @@ def delete_tag_category( tag_categories.delete_category(category) snapshots.delete(category, ctx.user) ctx.session.commit() - tags.export_to_json() return {} @@ -89,5 +86,4 @@ def set_tag_category_as_default( ctx.session.flush() snapshots.modify(category, ctx.user) ctx.session.commit() - tags.export_to_json() return _serialize(ctx, category) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index bb82e57b..406f6e18 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -224,7 +224,13 @@ class PostSerializer(serialization.BaseSerializer): return self.post.flags def serialize_tags(self) -> Any: - return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)] + return [ + { + 'names': [name.name for name in tag.names], + 'category': tag.category.name, + 'usages': tag.post_count, + } + for tag in tags.sort_tags(self.post.tags)] def serialize_relations(self) -> Any: return sorted( diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 55ee632e..7d92f1e7 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -72,6 +72,14 @@ def sort_tags(tags: List[model.Tag]) -> List[model.Tag]: ) +def serialize_relation(tag): + return { + 'names': [tag_name.name for tag_name in tag.names], + 'category': tag.category.name, + 'usages': tag.post_count, + } + + class TagSerializer(serialization.BaseSerializer): def __init__(self, tag: model.Tag) -> None: self.tag = tag @@ -112,12 +120,12 @@ class TagSerializer(serialization.BaseSerializer): def serialize_suggestions(self) -> Any: return [ - relation.names[0].name + serialize_relation(relation) for relation in sort_tags(self.tag.suggestions)] def serialize_implications(self) -> Any: return [ - relation.names[0].name + serialize_relation(relation) for relation in sort_tags(self.tag.implications)] @@ -128,67 +136,6 @@ def serialize_tag( return TagSerializer(tag).serialize(options) -def export_to_json() -> None: - tags = {} # type: Dict[int, Any] - categories = {} # type: Dict[int, Any] - - for result in db.session.query( - model.TagCategory.tag_category_id, - model.TagCategory.name, - model.TagCategory.color).all(): - categories[result[0]] = { - 'name': result[1], - 'color': result[2], - } - - for result in ( - db.session - .query(model.TagName.tag_id, model.TagName.name) - .order_by(model.TagName.order) - .all()): - if not result[0] in tags: - tags[result[0]] = {'names': []} - tags[result[0]]['names'].append(result[1]) - - for result in ( - db.session - .query(model.TagSuggestion.parent_id, model.TagName.name) - .join( - model.TagName, - model.TagName.tag_id == model.TagSuggestion.child_id) - .all()): - if 'suggestions' not in tags[result[0]]: - tags[result[0]]['suggestions'] = [] - tags[result[0]]['suggestions'].append(result[1]) - - for result in ( - db.session - .query(model.TagImplication.parent_id, model.TagName.name) - .join( - model.TagName, - model.TagName.tag_id == model.TagImplication.child_id) - .all()): - if 'implications' not in tags[result[0]]: - tags[result[0]]['implications'] = [] - tags[result[0]]['implications'].append(result[1]) - - for result in db.session.query( - model.Tag.tag_id, - model.Tag.category_id, - model.Tag.post_count).all(): - tags[result[0]]['category'] = categories[result[1]]['name'] - tags[result[0]]['usages'] = result[2] - - output = { - 'categories': list(categories.values()), - 'tags': list(tags.values()), - } - - export_path = os.path.join(config.config['data_dir'], 'tags.json') - with open(export_path, 'w') as handle: - handle.write(json.dumps(output, separators=(',', ':'))) - - def try_get_tag_by_name(name: str) -> Optional[model.Tag]: return ( db.session diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index 0a8202ce..10b34b1c 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -134,7 +134,7 @@ class Executor: 'offset': offset, 'limit': limit, 'total': count, - 'results': [serializer(entity) for entity in entities], + 'results': list([serializer(entity) for entity in entities]), } def _prepare_db_query( diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index a0561a34..6edcdd36 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -30,7 +30,6 @@ def test_creating_minimal_posts( patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.update_post_thumbnail'), \ patch('szurubooru.func.posts.serialize_post'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.snapshots.create'): posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' @@ -62,7 +61,6 @@ def test_creating_minimal_posts( posts.serialize_post.assert_called_once_with( post, auth_user, options=[]) snapshots.create.assert_called_once_with(post, auth_user) - tags.export_to_json.assert_called_once_with() def test_creating_full_posts(context_factory, post_factory, user_factory): @@ -78,7 +76,6 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): patch('szurubooru.func.posts.update_post_notes'), \ patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.serialize_post'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.snapshots.create'): posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' @@ -111,7 +108,6 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): posts.serialize_post.assert_called_once_with( post, auth_user, options=[]) snapshots.create.assert_called_once_with(post, auth_user) - tags.export_to_json.assert_called_once_with() def test_anonymous_uploads( @@ -121,8 +117,7 @@ def test_anonymous_uploads( db.session.add(post) db.session.flush() - with patch('szurubooru.func.tags.export_to_json'), \ - patch('szurubooru.func.posts.serialize_post'), \ + with patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ @@ -152,7 +147,6 @@ def test_creating_from_url_saves_source( db.session.flush() with patch('szurubooru.func.net.download'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): @@ -183,7 +177,6 @@ def test_creating_from_url_with_source_specified( db.session.flush() with patch('szurubooru.func.net.download'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): @@ -245,7 +238,6 @@ def test_omitting_optional_field( patch('szurubooru.func.posts.update_post_notes'), \ patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.serialize_post'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.snapshots.create'): posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' diff --git a/server/szurubooru/tests/api/test_post_deleting.py b/server/szurubooru/tests/api/test_post_deleting.py index 643b952c..e35a4488 100644 --- a/server/szurubooru/tests/api/test_post_deleting.py +++ b/server/szurubooru/tests/api/test_post_deleting.py @@ -14,15 +14,13 @@ def test_deleting(user_factory, post_factory, context_factory): post = post_factory(id=1) db.session.add(post) db.session.flush() - with patch('szurubooru.func.tags.export_to_json'), \ - patch('szurubooru.func.snapshots.delete'): + with patch('szurubooru.func.snapshots.delete'): result = api.post_api.delete_post( context_factory(params={'version': 1}, user=auth_user), {'post_id': 1}) assert result == {} assert db.session.query(model.Post).count() == 0 snapshots.delete.assert_called_once_with(post, auth_user) - tags.export_to_json.assert_called_once_with() def test_trying_to_delete_non_existing(user_factory, context_factory): diff --git a/server/szurubooru/tests/api/test_post_updating.py b/server/szurubooru/tests/api/test_post_updating.py index d3649307..fa298eaa 100644 --- a/server/szurubooru/tests/api/test_post_updating.py +++ b/server/szurubooru/tests/api/test_post_updating.py @@ -39,7 +39,6 @@ def test_post_updating( patch('szurubooru.func.posts.update_post_notes'), \ patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.serialize_post'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.snapshots.modify'), \ fake_datetime('1997-01-01'): posts.serialize_post.return_value = 'serialized post' @@ -78,7 +77,6 @@ def test_post_updating( posts.serialize_post.assert_called_once_with( post, auth_user, options=[]) snapshots.modify.assert_called_once_with(post, auth_user) - tags.export_to_json.assert_called_once_with() assert post.last_edit_time == datetime(1997, 1, 1) @@ -88,7 +86,6 @@ def test_uploading_from_url_saves_source( db.session.add(post) db.session.flush() with patch('szurubooru.func.net.download'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.update_post_content'), \ patch('szurubooru.func.posts.update_post_source'), \ @@ -110,7 +107,6 @@ def test_uploading_from_url_with_source_specified( db.session.add(post) db.session.flush() with patch('szurubooru.func.net.download'), \ - patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.update_post_content'), \ patch('szurubooru.func.posts.update_post_source'), \ diff --git a/server/szurubooru/tests/api/test_tag_category_creating.py b/server/szurubooru/tests/api/test_tag_category_creating.py index fbd8b1bc..47e8405c 100644 --- a/server/szurubooru/tests/api/test_tag_category_creating.py +++ b/server/szurubooru/tests/api/test_tag_category_creating.py @@ -24,8 +24,7 @@ def test_creating_category( with patch('szurubooru.func.tag_categories.create_category'), \ patch('szurubooru.func.tag_categories.serialize_category'), \ patch('szurubooru.func.tag_categories.update_category_name'), \ - patch('szurubooru.func.snapshots.create'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.snapshots.create'): tag_categories.create_category.return_value = category tag_categories.update_category_name.side_effect = _update_category_name tag_categories.serialize_category.return_value = 'serialized category' @@ -35,7 +34,6 @@ def test_creating_category( assert result == 'serialized category' tag_categories.create_category.assert_called_once_with('meta', 'black') snapshots.create.assert_called_once_with(category, auth_user) - tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize('field', ['name', 'color']) diff --git a/server/szurubooru/tests/api/test_tag_category_deleting.py b/server/szurubooru/tests/api/test_tag_category_deleting.py index 1fc86431..2bee5137 100644 --- a/server/szurubooru/tests/api/test_tag_category_deleting.py +++ b/server/szurubooru/tests/api/test_tag_category_deleting.py @@ -17,8 +17,7 @@ def test_deleting(user_factory, tag_category_factory, context_factory): db.session.add(tag_category_factory(name='root')) db.session.add(category) db.session.flush() - with patch('szurubooru.func.snapshots.delete'), \ - patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.snapshots.delete'): result = api.tag_category_api.delete_tag_category( context_factory(params={'version': 1}, user=auth_user), {'category_name': 'category'}) @@ -26,7 +25,6 @@ def test_deleting(user_factory, tag_category_factory, context_factory): assert db.session.query(model.TagCategory).count() == 1 assert db.session.query(model.TagCategory).one().name == 'root' snapshots.delete.assert_called_once_with(category, auth_user) - tags.export_to_json.assert_called_once_with() def test_trying_to_delete_used( diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index d406dd1f..24a9f6ed 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -27,8 +27,7 @@ def test_simple_updating(user_factory, tag_category_factory, context_factory): with patch('szurubooru.func.tag_categories.serialize_category'), \ patch('szurubooru.func.tag_categories.update_category_name'), \ patch('szurubooru.func.tag_categories.update_category_color'), \ - patch('szurubooru.func.snapshots.modify'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.snapshots.modify'): tag_categories.update_category_name.side_effect = _update_category_name tag_categories.serialize_category.return_value = 'serialized category' result = api.tag_category_api.update_tag_category( @@ -42,7 +41,6 @@ def test_simple_updating(user_factory, tag_category_factory, context_factory): tag_categories.update_category_color.assert_called_once_with( category, 'white') snapshots.modify.assert_called_once_with(category, auth_user) - tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize('field', ['name', 'color']) @@ -56,8 +54,7 @@ def test_omitting_optional_field( } del params[field] with patch('szurubooru.func.tag_categories.serialize_category'), \ - patch('szurubooru.func.tag_categories.update_category_name'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.tag_categories.update_category_name'): api.tag_category_api.update_tag_category( context_factory( params={**params, **{'version': 1}}, @@ -95,8 +92,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory): db.session.add(category) db.session.commit() with patch('szurubooru.func.tag_categories.serialize_category'), \ - patch('szurubooru.func.tag_categories.set_default_category'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.tag_categories.set_default_category'): tag_categories.update_category_name.side_effect = _update_category_name tag_categories.serialize_category.return_value = 'serialized category' result = api.tag_category_api.set_tag_category_as_default( diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index 0c8d5824..4cee7101 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -15,8 +15,7 @@ def test_creating_simple_tags(tag_factory, user_factory, context_factory): with patch('szurubooru.func.tags.create_tag'), \ patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ patch('szurubooru.func.tags.serialize_tag'), \ - patch('szurubooru.func.snapshots.create'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.snapshots.create'): tags.get_or_create_tags_by_names.return_value = ([], []) tags.create_tag.return_value = tag tags.serialize_tag.return_value = 'serialized tag' @@ -34,7 +33,6 @@ def test_creating_simple_tags(tag_factory, user_factory, context_factory): tags.create_tag.assert_called_once_with( ['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2']) snapshots.create.assert_called_once_with(tag, auth_user) - tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize('field', ['names', 'category']) @@ -64,8 +62,7 @@ def test_omitting_optional_field( } del params[field] with patch('szurubooru.func.tags.create_tag'), \ - patch('szurubooru.func.tags.serialize_tag'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.tags.serialize_tag'): tags.create_tag.return_value = tag_factory() api.tag_api.create_tag( context_factory( diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index fbd35e12..a0367f22 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -14,15 +14,13 @@ def test_deleting(user_factory, tag_factory, context_factory): tag = tag_factory(names=['tag']) db.session.add(tag) db.session.commit() - with patch('szurubooru.func.tags.export_to_json'), \ - patch('szurubooru.func.snapshots.delete'): + with patch('szurubooru.func.snapshots.delete'): result = api.tag_api.delete_tag( context_factory(params={'version': 1}, user=auth_user), {'tag_name': 'tag'}) assert result == {} assert db.session.query(model.Tag).count() == 0 snapshots.delete.assert_called_once_with(tag, auth_user) - tags.export_to_json.assert_called_once_with() def test_deleting_used( @@ -32,15 +30,14 @@ def test_deleting_used( post.tags.append(tag) db.session.add_all([tag, post]) db.session.commit() - with patch('szurubooru.func.tags.export_to_json'): - api.tag_api.delete_tag( - context_factory( - params={'version': 1}, - user=user_factory(rank=model.User.RANK_REGULAR)), - {'tag_name': 'tag'}) - db.session.refresh(post) - assert db.session.query(model.Tag).count() == 0 - assert post.tags == [] + api.tag_api.delete_tag( + context_factory( + params={'version': 1}, + user=user_factory(rank=model.User.RANK_REGULAR)), + {'tag_name': 'tag'}) + db.session.refresh(post) + assert db.session.query(model.Tag).count() == 0 + assert post.tags == [] def test_trying_to_delete_non_existing(user_factory, context_factory): diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index 484fbfa6..671e2e45 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -25,8 +25,7 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory): assert target_tag.post_count == 0 with patch('szurubooru.func.tags.serialize_tag'), \ patch('szurubooru.func.tags.merge_tags'), \ - patch('szurubooru.func.snapshots.merge'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.snapshots.merge'): api.tag_api.merge_tags( context_factory( params={ @@ -39,7 +38,6 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory): tags.merge_tags.called_once_with(source_tag, target_tag) snapshots.merge.assert_called_once_with( source_tag, target_tag, auth_user) - tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize( diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index fb63e353..0b8aa703 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -31,8 +31,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory): patch('szurubooru.func.tags.update_tag_suggestions'), \ patch('szurubooru.func.tags.update_tag_implications'), \ patch('szurubooru.func.tags.serialize_tag'), \ - patch('szurubooru.func.snapshots.modify'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.snapshots.modify'): tags.get_or_create_tags_by_names.return_value = ([], []) tags.serialize_tag.return_value = 'serialized tag' result = api.tag_api.update_tag( @@ -58,7 +57,6 @@ def test_simple_updating(user_factory, tag_factory, context_factory): tag, ['imp1', 'imp2']) tags.serialize_tag.assert_called_once_with(tag, options=[]) snapshots.modify.assert_called_once_with(tag, auth_user) - tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize( @@ -84,8 +82,7 @@ def test_omitting_optional_field( with patch('szurubooru.func.tags.create_tag'), \ patch('szurubooru.func.tags.update_tag_names'), \ patch('szurubooru.func.tags.update_tag_category_name'), \ - patch('szurubooru.func.tags.serialize_tag'), \ - patch('szurubooru.func.tags.export_to_json'): + patch('szurubooru.func.tags.serialize_tag'): api.tag_api.update_tag( context_factory( params={**params, **{'version': 1}}, diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 2c41d96d..e296ec0a 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -75,7 +75,11 @@ def test_serialize_post_when_empty(): def test_serialize_post( - user_factory, comment_factory, tag_factory, config_injector): + user_factory, + comment_factory, + tag_factory, + tag_category_factory, + config_injector): config_injector({'data_url': 'http://example.com/', 'secret': 'test'}) with patch('szurubooru.func.comments.serialize_comment'), \ patch('szurubooru.func.users.serialize_micro_user'), \ @@ -92,8 +96,12 @@ def test_serialize_post( post.creation_time = datetime(1997, 1, 1) post.last_edit_time = datetime(1998, 1, 1) post.tags = [ - tag_factory(names=['tag1', 'tag2']), - tag_factory(names=['tag3']) + tag_factory( + names=['tag1', 'tag2'], + category=tag_category_factory('test-cat1')), + tag_factory( + names=['tag3'], + category=tag_category_factory('test-cat2')) ] post.safety = model.Post.SAFETY_SAFE post.source = '4gag' @@ -143,7 +151,7 @@ def test_serialize_post( db.session.flush() result = posts.serialize_post(post, auth_user) - result['tags'].sort() + result['tags'].sort(key=lambda tag: tag['names'][0]) assert result == { 'id': 1, @@ -162,7 +170,17 @@ def test_serialize_post( 'http://example.com/' 'generated-thumbnails/1_244c8840887984c4.jpg', 'flags': ['loop'], - 'tags': ['tag1', 'tag3'], + 'tags': [ + { + 'names': ['tag1', 'tag2'], + 'category': 'test-cat1', 'usages': 1, + }, + { + 'names': ['tag3'], + 'category': 'test-cat2', + 'usages': 1, + }, + ], 'relations': [], 'notes': [], 'user': 'post author', diff --git a/server/szurubooru/tests/func/test_tags.py b/server/szurubooru/tests/func/test_tags.py index 712c8e38..2f888ef5 100644 --- a/server/szurubooru/tests/func/test_tags.py +++ b/server/szurubooru/tests/func/test_tags.py @@ -45,15 +45,18 @@ def test_serialize_tag_when_empty(): def test_serialize_tag(post_factory, tag_factory, tag_category_factory): - tag = tag_factory( - names=['tag1', 'tag2'], - category=tag_category_factory(name='cat')) + cat = tag_category_factory(name='cat') + tag = tag_factory(names=['tag1', 'tag2'], category=cat) tag.tag_id = 1 tag.description = 'description' tag.suggestions = [ - tag_factory(names=['sug1']), tag_factory(names=['sug2'])] + tag_factory(names=['sug1'], category=cat), + tag_factory(names=['sug2'], category=cat), + ] tag.implications = [ - tag_factory(names=['impl1']), tag_factory(names=['impl2'])] + tag_factory(names=['impl1'], category=cat), + tag_factory(names=['impl2'], category=cat), + ] tag.last_edit_time = datetime(1998, 1, 1) post1 = post_factory() post2 = post_factory() @@ -62,8 +65,8 @@ def test_serialize_tag(post_factory, tag_factory, tag_category_factory): db.session.add_all([tag, post1, post2]) db.session.flush() result = tags.serialize_tag(tag) - result['suggestions'].sort() - result['implications'].sort() + result['suggestions'].sort(key=lambda relation: relation['names'][0]) + result['implications'].sort(key=lambda relation: relation['names'][0]) assert result == { 'names': ['tag1', 'tag2'], 'version': 1, @@ -71,69 +74,18 @@ def test_serialize_tag(post_factory, tag_factory, tag_category_factory): 'creationTime': datetime(1996, 1, 1, 0, 0), 'lastEditTime': datetime(1998, 1, 1, 0, 0), 'description': 'description', - 'suggestions': ['sug1', 'sug2'], - 'implications': ['impl1', 'impl2'], + 'suggestions': [ + {'names': ['sug1'], 'category': 'cat', 'usages': 0}, + {'names': ['sug2'], 'category': 'cat', 'usages': 0}, + ], + 'implications': [ + {'names': ['impl1'], 'category': 'cat', 'usages': 0}, + {'names': ['impl2'], 'category': 'cat', 'usages': 0}, + ], 'usages': 2, } -def test_export_to_json( - tmpdir, - query_counter, - config_injector, - post_factory, - tag_factory, - tag_category_factory): - config_injector({'data_dir': str(tmpdir)}) - cat1 = tag_category_factory(name='cat1', color='black') - cat2 = tag_category_factory(name='cat2', color='white') - tag = tag_factory(names=['alias1', 'alias2'], category=cat2) - tag.suggestions = [ - tag_factory(names=['sug1'], category=cat1), - tag_factory(names=['sug2'], category=cat1), - ] - tag.implications = [ - tag_factory(names=['imp1'], category=cat1), - tag_factory(names=['imp2'], category=cat1), - ] - post = post_factory() - post.tags = [tag] - db.session.add_all([post, tag]) - db.session.flush() - - with query_counter: - tags.export_to_json() - assert len(query_counter.statements) == 5 - - export_path = os.path.join(str(tmpdir), 'tags.json') - assert os.path.exists(export_path) - with open(export_path, 'r') as handle: - actual_json = json.loads(handle.read()) - assert actual_json['tags'] - assert actual_json['categories'] - actual_json['tags'].sort(key=lambda tag: tag['names'][0]) - actual_json['categories'].sort(key=lambda category: category['name']) - assert actual_json == { - 'tags': [ - { - 'names': ['alias1', 'alias2'], - 'usages': 1, - 'category': 'cat2', - 'suggestions': ['sug1', 'sug2'], - 'implications': ['imp1', 'imp2'], - }, - {'names': ['imp1'], 'usages': 0, 'category': 'cat1'}, - {'names': ['imp2'], 'usages': 0, 'category': 'cat1'}, - {'names': ['sug1'], 'usages': 0, 'category': 'cat1'}, - {'names': ['sug2'], 'usages': 0, 'category': 'cat1'}, - ], - 'categories': [ - {'name': 'cat1', 'color': 'black'}, - {'name': 'cat2', 'color': 'white'}, - ] - } - - @pytest.mark.parametrize('name_to_search,expected_to_find', [ ('name', True), ('NAME', True), From 36698cddc2fa8cf16d40c20e8dc98ffb141d6493 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 1 Oct 2017 22:00:42 +0200 Subject: [PATCH 089/159] client/posts: fix promise chaining --- client/js/controllers/post_list_controller.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/js/controllers/post_list_controller.js b/client/js/controllers/post_list_controller.js index 039ca187..f60cb8b7 100644 --- a/client/js/controllers/post_list_controller.js +++ b/client/js/controllers/post_list_controller.js @@ -65,7 +65,7 @@ class PostListController { Promise.all( this._bulkEditTags.map(tag => e.detail.post.tags.addByName(tag))) - .then(() => { e.detail.post.save(); }) + .then(e.detail.post.save()) .catch(error => window.alert(error.message)); } From 4848bee5e39bb7ca7cd100054b9b296abb1b2905 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 1 Oct 2017 22:02:49 +0200 Subject: [PATCH 090/159] client/tags: remove unused cruft --- .../js/controls/tag_auto_complete_control.js | 4 ---- client/js/tags.js | 18 ------------------ client/js/views/tag_merge_view.js | 1 - 3 files changed, 23 deletions(-) diff --git a/client/js/controls/tag_auto_complete_control.js b/client/js/controls/tag_auto_complete_control.js index b7dda201..3d9130ec 100644 --- a/client/js/controls/tag_auto_complete_control.js +++ b/client/js/controls/tag_auto_complete_control.js @@ -5,10 +5,6 @@ const views = require('../util/views.js'); const TagList = require('../models/tag_list.js'); const AutoCompleteControl = require('./auto_complete_control.js'); -function _escapeSearch(text) { - return text.replace('\\', '\\\\').replace(':', '\\:'); -} - function _tagListToMatches(tags, options) { return [...tags].sort((tag1, tag2) => { return tag2.usages - tag1.usages; diff --git a/client/js/tags.js b/client/js/tags.js index 8f83a428..c037f6f6 100644 --- a/client/js/tags.js +++ b/client/js/tags.js @@ -21,24 +21,6 @@ function refreshCategoryColorMap() { }); } -function getAllImplications(tagName) { - let implications = []; - let check = [tagName]; - while (check.length) { - let tagName = check.pop(); - const actualTag = getTagByName(tagName) || {}; - for (let implication of actualTag.implications || []) { - if (implications.includes(implication)) { - continue; - } - implications.push(implication); - check.push(implication); - } - } - return Array.from(implications); -} - module.exports = { refreshCategoryColorMap: refreshCategoryColorMap, - getAllImplications: getAllImplications, }; diff --git a/client/js/views/tag_merge_view.js b/client/js/views/tag_merge_view.js index 9b49097d..12286800 100644 --- a/client/js/views/tag_merge_view.js +++ b/client/js/views/tag_merge_view.js @@ -2,7 +2,6 @@ const config = require('../config.js'); const events = require('../events.js'); -const misc = require('../util/misc.js'); const views = require('../util/views.js'); const TagAutoCompleteControl = require('../controls/tag_auto_complete_control.js'); From cdf454818cd42ec249eaa45d6a35edc823b2fafa Mon Sep 17 00:00:00 2001 From: rr- Date: Mon, 2 Oct 2017 21:08:01 +0200 Subject: [PATCH 091/159] client: widen search inputs to match post search --- client/css/tag-list-view.styl | 2 +- client/css/user-list-view.styl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client/css/tag-list-view.styl b/client/css/tag-list-view.styl index 8ae7bcb8..ac4b73de 100644 --- a/client/css/tag-list-view.styl +++ b/client/css/tag-list-view.styl @@ -46,7 +46,7 @@ form width: auto input[name=search-text] - max-width: 15em + width: 25em .append font-size: 0.95em color: $inactive-link-color diff --git a/client/css/user-list-view.styl b/client/css/user-list-view.styl index 3d49d941..f575a97c 100644 --- a/client/css/user-list-view.styl +++ b/client/css/user-list-view.styl @@ -33,7 +33,7 @@ form width: auto input[name=search-text] - max-width: 15em + width: 25em .append font-size: 0.95em color: $inactive-link-color From f8c7375b01314589d1b227be29e239e02896af23 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 8 Oct 2017 21:38:28 +0200 Subject: [PATCH 092/159] server/tags: allow uppercase tag category colors i.e. colors such as "#FF0000" --- server/szurubooru/func/tag_categories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index bec2f0de..f1951a8c 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -105,7 +105,7 @@ def update_category_color(category: model.TagCategory, color: str) -> None: assert category if not color: raise InvalidTagCategoryColorError('Color cannot be empty.') - if not re.match(r'^#?[0-9a-z]+$', color): + if not re.match(r'^#?[0-9A-Za-z]+$', color): raise InvalidTagCategoryColorError('Invalid color.') if util.value_exceeds_column_size(color, model.TagCategory.color): raise InvalidTagCategoryColorError('Color is too long.') From 85cb3d47022f278434a65230bb8e053321c563d4 Mon Sep 17 00:00:00 2001 From: Michael Serajnik Date: Sat, 2 Dec 2017 23:33:21 +0100 Subject: [PATCH 093/159] client/help: fix spelling issues --- client/html/help_keyboard.tpl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/html/help_keyboard.tpl b/client/html/help_keyboard.tpl index 76bfe2a2..f200ce02 100644 --- a/client/html/help_keyboard.tpl +++ b/client/html/help_keyboard.tpl @@ -41,7 +41,7 @@ shortcuts:

    -

    Additionally, each item in top navigation can be accessed using feature -called “access keys”. Pressing underlined letter while holding -Shfit or Alt+Shift (depending on your browser) will go to the desired page -(most browsers) or focus the link (IE).

    +

    Additionally, each item in the top navigation can be accessed using a +feature called “access keys”. Pressing the underlined letter while +holding Shift or Alt+Shift (depending on your browser) will go to the desired +page (most browsers) or focus the link (IE).

    From 69421464f69d3ed63c329c548b4758eb16d3fd98 Mon Sep 17 00:00:00 2001 From: Michael Serajnik Date: Fri, 15 Dec 2017 19:53:52 +0100 Subject: [PATCH 094/159] client/posts: override resize mode in home view --- client/js/controls/post_content_control.js | 9 +++++++-- client/js/views/home_view.js | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/client/js/controls/post_content_control.js b/client/js/controls/post_content_control.js index ca2fca85..856fb775 100644 --- a/client/js/controls/post_content_control.js +++ b/client/js/controls/post_content_control.js @@ -5,18 +5,23 @@ const views = require('../util/views.js'); const optimizedResize = require('../util/optimized_resize.js'); class PostContentControl { - constructor(hostNode, post, viewportSizeCalculator) { + constructor(hostNode, post, viewportSizeCalculator, fitFunctionOverride) { this._post = post; this._viewportSizeCalculator = viewportSizeCalculator; this._hostNode = hostNode; this._template = views.getTemplate('post-content'); + let fitMode = settings.get().fitMode; + if (typeof fitFunctionOverride !== 'undefined') { + fitMode = fitFunctionOverride; + } + this._currentFitFunction = { 'fit-both': this.fitBoth, 'fit-original': this.fitOriginal, 'fit-width': this.fitWidth, 'fit-height': this.fitHeight, - }[settings.get().fitMode] || this.fitBoth; + }[fitMode] || this.fitBoth; this._install(); diff --git a/client/js/views/home_view.js b/client/js/views/home_view.js index 7965e3d2..c9267053 100644 --- a/client/js/views/home_view.js +++ b/client/js/views/home_view.js @@ -62,7 +62,8 @@ class HomeView { window.innerWidth * 0.8, window.innerHeight * 0.7, ]; - }); + }, + 'fit-both'); this._postNotesOverlay = new PostNotesOverlayControl( this._postContainerNode.querySelector('.post-overlay'), From 59d8b0d4c579771d3bfa2f757fe3498f3de3ddfe Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 6 Jan 2018 21:35:33 +0100 Subject: [PATCH 095/159] client: update dependencies --- client/package-lock.json | 45 ++++++++++++++++++++++------------------ client/package.json | 10 ++++----- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/client/package-lock.json b/client/package-lock.json index a4d2bf9f..d54d15bc 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -840,9 +840,9 @@ "integrity": "sha1-fB0W1nmhu+WcoCys7PsBHiAfWh8=" }, "camelcase-keys": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-4.1.0.tgz", - "integrity": "sha1-IU00jMVFfzkxaiwxzD43JGMl5z8=", + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/camelcase-keys/-/camelcase-keys-4.2.0.tgz", + "integrity": "sha1-oqpfsa9oh1glnDLBQUJteJI7m3c=", "requires": { "camelcase": "4.1.0", "map-obj": "2.0.0", @@ -1517,9 +1517,9 @@ "integrity": "sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE=" }, "js-cookie": { - "version": "2.1.4", - "resolved": "https://registry.npmjs.org/js-cookie/-/js-cookie-2.1.4.tgz", - "integrity": "sha1-2k7FA4ZvFJ0WTPJfV57zEBUCXY0=" + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/js-cookie/-/js-cookie-2.2.0.tgz", + "integrity": "sha1-Gywnmm7s44ChIWi5JIUmWzWx7/s=" }, "js-tokens": { "version": "3.0.2", @@ -1527,9 +1527,9 @@ "integrity": "sha1-mGbfOVECEw449/mWvOtlRDIJwls=" }, "js-yaml": { - "version": "3.9.1", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.9.1.tgz", - "integrity": "sha512-CbcG379L1e+mWBnLvHWWeLs8GyV/EMw862uLI3c+GxVyDHWZcjZinwuBd3iW2pgxgIlksW/1vNJa4to+RvDOww==", + "version": "3.10.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-3.10.0.tgz", + "integrity": "sha512-O2v52ffjLa9VeM43J4XocZE//WT9N0IiwDa3KSHH7Tu8CtH+1qM8SIZvnsTh6v+4yFy5KUY3BHUVwjpfAWsjIA==", "requires": { "argparse": "1.0.9", "esprima": "4.0.0" @@ -1643,9 +1643,9 @@ "integrity": "sha1-plzSkIepJZi4eRJXpSPgISIqwfk=" }, "marked": { - "version": "0.3.6", - "resolved": "https://registry.npmjs.org/marked/-/marked-0.3.6.tgz", - "integrity": "sha1-ssbGGPzOzk74bE/Gy4p8v1rtqNc=" + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/marked/-/marked-0.3.9.tgz", + "integrity": "sha512-nW5u0dxpXxHfkHzzrveY45gCbi+R4PaO4WRZYqZNl+vB0hVGeqlFn0aOg1c8AKL63TrNFn9Bm2UP4AdiZ9TPLw==" }, "merge": { "version": "1.2.0", @@ -2353,18 +2353,23 @@ "integrity": "sha1-hnrHTjhkGHsdPUfZlqeOxciDB3c=" }, "uglify-es": { - "version": "3.0.28", - "resolved": "https://registry.npmjs.org/uglify-es/-/uglify-es-3.0.28.tgz", - "integrity": "sha512-xw1hJsSp361OO0Sq0XvNyTI2wfQ4eKNljfSYyeYX/dz9lKEDj+DK+A8CzB0NmoCwWX1MnEx9f16HlkKXyG65CQ==", + "version": "3.3.4", + "resolved": "https://registry.npmjs.org/uglify-es/-/uglify-es-3.3.4.tgz", + "integrity": "sha512-vDOyDaf7LcABZI5oJt8bin5FD8kYONux5jd8FY6SsV2SfD+MMXaPeGUotysbycSxdu170y5IQ8FvlKzU/TUryw==", "requires": { - "commander": "2.11.0", - "source-map": "0.5.6" + "commander": "2.12.2", + "source-map": "0.6.1" }, "dependencies": { "commander": { - "version": "2.11.0", - "resolved": "https://registry.npmjs.org/commander/-/commander-2.11.0.tgz", - "integrity": "sha512-b0553uYA5YAEGgyYIGYROzKQ7X5RAqedkfjiZxwi0kL1g3bOaBNNZfYkzt/CL0umgD5wc9Jec2FbB98CjkMRvQ==" + "version": "2.12.2", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.12.2.tgz", + "integrity": "sha512-BFnaq5ZOGcDN7FlrtBT4xxkgIToalIIxwjxLWVJ8bGTpe1LroqMiqQXdA7ygc7CRvaYS+9zfPGFnJqFSayx+AA==" + }, + "source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==" } } }, diff --git a/client/package.json b/client/package.json index b7d3605c..c482147e 100644 --- a/client/package.json +++ b/client/package.json @@ -11,20 +11,20 @@ "babelify": "^7.2.0", "browserify": "^13.0.0", "camelcase": "^2.1.1", - "camelcase-keys": "^4.1.0", + "camelcase-keys": "^4.2.0", "csso": "^1.8.0", "font-awesome": "^4.6.1", "glob": "^7.1.2", "html-minifier": "^1.3.1", - "js-cookie": "^2.1.4", - "js-yaml": "^3.9.1", - "marked": "~0.3.2", + "js-cookie": "^2.2.0", + "js-yaml": "^3.10.0", + "marked": "^0.3.9", "merge": "^1.2.0", "mousetrap": "^1.6.1", "nprogress": "^0.2.0", "stylus": "^0.54.2", "superagent": "^1.8.3", - "uglify-es": "^3.0.28", + "uglify-es": "^3.3.4", "underscore": "^1.8.3" } } From a1fbeb91a0984f5daf477a8813dc03979d4ae136 Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 10 Feb 2018 14:03:43 +0100 Subject: [PATCH 096/159] server/users: fix checking passwords with colons --- server/szurubooru/middleware/authenticator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/szurubooru/middleware/authenticator.py b/server/szurubooru/middleware/authenticator.py index fa492f94..644fe3b3 100644 --- a/server/szurubooru/middleware/authenticator.py +++ b/server/szurubooru/middleware/authenticator.py @@ -24,7 +24,7 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]: 'ValidationError', 'Only basic HTTP authentication is supported.') username, password = base64.decodebytes( - credentials.encode('ascii')).decode('utf8').split(':') + credentials.encode('ascii')).decode('utf8').split(':', 1) return _authenticate(username, password) except ValueError as err: msg = ( From 4b3529272ea238cb35f8f6f12c299e24ce0a88f9 Mon Sep 17 00:00:00 2001 From: ReAnzu Date: Fri, 23 Feb 2018 22:05:58 -0600 Subject: [PATCH 097/159] server/users: let administrators add new users * Added functionality for administrators to directly add users to the application * Added permission users:create:any to handle level that users are allowed to create other users * Moved old permission users:create to users:create:self --- .../controllers/top_navigation_controller.js | 6 ++++-- .../user_registration_controller.js | 20 ++++++++++++++----- config.yaml.dist | 3 ++- server/szurubooru/api/user_api.py | 7 ++++++- .../tests/api/test_user_creating.py | 2 +- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/client/js/controllers/top_navigation_controller.js b/client/js/controllers/top_navigation_controller.js index b1de03ec..550400cf 100644 --- a/client/js/controllers/top_navigation_controller.js +++ b/client/js/controllers/top_navigation_controller.js @@ -47,10 +47,12 @@ class TopNavigationController { topNavigation.hide('users'); } if (api.isLoggedIn()) { - topNavigation.hide('register'); + if (!api.hasPrivilege('users:create:any')) { + topNavigation.hide('register'); + } topNavigation.hide('login'); } else { - if (!api.hasPrivilege('users:create')) { + if (!api.hasPrivilege('users:create:self')) { topNavigation.hide('register'); } topNavigation.hide('account'); diff --git a/client/js/controllers/user_registration_controller.js b/client/js/controllers/user_registration_controller.js index 7d822380..78b94024 100644 --- a/client/js/controllers/user_registration_controller.js +++ b/client/js/controllers/user_registration_controller.js @@ -10,7 +10,7 @@ const EmptyView = require('../views/empty_view.js'); class UserRegistrationController { constructor() { - if (!api.hasPrivilege('users:create')) { + if (!api.hasPrivilege('users:create:self')) { this._view = new EmptyView(); this._view.showError('Registration is closed.'); return; @@ -29,12 +29,22 @@ class UserRegistrationController { user.name = e.detail.name; user.email = e.detail.email; user.password = e.detail.password; + const isLoggedIn = api.isLoggedIn(); user.save().then(() => { - api.forget(); - return api.login(e.detail.name, e.detail.password, false); + if (isLoggedIn) { + return Promise.resolve(); + } else { + api.forget(); + return api.login(e.detail.name, e.detail.password, false); + } }).then(() => { - const ctx = router.show(uri.formatClientLink()); - ctx.controller.showSuccess('Welcome aboard!'); + if (isLoggedIn) { + const ctx = router.show(uri.formatClientLink('users')); + ctx.controller.showSuccess('User added!'); + } else { + const ctx = router.show(uri.formatClientLink()); + ctx.controller.showSuccess('Welcome aboard!'); + } }, error => { this._view.showError(error.message); this._view.enableForm(); diff --git a/config.yaml.dist b/config.yaml.dist index 42267452..7273ea7b 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -62,7 +62,8 @@ default_rank: regular privileges: - 'users:create': anonymous + 'users:create:self': anonymous # Registration permission + 'users:create:any': administrator 'users:list': regular 'users:view': regular 'users:edit:any:name': moderator diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index e456f22e..5e14fabe 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -26,7 +26,11 @@ def get_users( @rest.routes.post('/users/?') def create_user( ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response: - auth.verify_privilege(ctx.user, 'users:create') + if ctx.user.user_id is None: + auth.verify_privilege(ctx.user, 'users:create:self') + else: + auth.verify_privilege(ctx.user, 'users:create:any') + name = ctx.get_param_as_string('name') password = ctx.get_param_as_string('password') email = ctx.get_param_as_string('email', default='') @@ -40,6 +44,7 @@ def create_user( ctx.get_file('avatar', default=b'')) ctx.session.add(user) ctx.session.commit() + return _serialize(ctx, user, force_show_email=True) diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index b5f36e39..699bfefb 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -6,7 +6,7 @@ from szurubooru.func import users @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'users:create': 'regular'}}) + config_injector({'privileges': {'users:create:self': 'regular'}}) def test_creating_user(user_factory, context_factory, fake_datetime): From 4ff8be6a2f0f2156a42790bcb56dee6f53acea92 Mon Sep 17 00:00:00 2001 From: ReAnzu Date: Thu, 8 Mar 2018 00:41:24 -0600 Subject: [PATCH 098/159] server/posts: ignore ffmpeg warnings Poorly formatted MP4 and WEBM sources can cause ffmpeg to throw a lot of warnings. However when there is byte ouptut, the generated thumbnail is valid. Add a bypass for the resize_fill function to allow ffmpeg to error. --- server/szurubooru/func/images.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index 16587ed5..9e62aef9 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -11,10 +11,6 @@ from szurubooru.func import mime, util logger = logging.getLogger(__name__) -_SCALE_FIT_FMT = ( - r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)') - - class Image: def __init__(self, content: bytes) -> None: self.content = content @@ -33,10 +29,14 @@ class Image: return self.info['streams'][0]['nb_read_frames'] def resize_fill(self, width: int, height: int) -> None: + width_greater = self.width > self.height + width, height = (-1, height) if width_greater else (width, -1) + cli = [ '-i', '{path}', '-f', 'image2', - '-vf', _SCALE_FIT_FMT.format(width=width, height=height), + '-filter:v', "scale='{width}:{height}'".format( + width=width, height=height), '-map', '0:v:0', '-vframes', '1', '-vcodec', 'png', @@ -50,7 +50,7 @@ class Image: '-ss', '%d' % math.floor(duration * 0.3), ] + cli - content = self._execute(cli) + content = self._execute(cli, ignore_error_if_data=True) if not content: raise errors.ProcessingError('Error while resizing image.') self.content = content @@ -79,7 +79,11 @@ class Image: '-', ]) - def _execute(self, cli: List[str], program: str = 'ffmpeg') -> bytes: + def _execute( + self, + cli: List[str], + program: str = 'ffmpeg', + ignore_error_if_data: bool = False) -> bytes: extension = mime.get_extension(mime.get_mime_type(self.content)) assert extension with util.create_temp_file(suffix='.' + extension) as handle: @@ -98,8 +102,11 @@ class Image: 'Failed to execute ffmpeg command (cli=%r, err=%r)', ' '.join(shlex.quote(arg) for arg in cli), err) - raise errors.ProcessingError( - 'Error while processing image.\n' + err.decode('utf-8')) + if ((len(out) > 0 and not ignore_error_if_data) + or len(out) == 0): + raise errors.ProcessingError( + 'Error while processing image.\n' + + err.decode('utf-8')) return out def _reload_info(self) -> None: From 12ec43f09887ca1682c6d02557336f3f45c7ba30 Mon Sep 17 00:00:00 2001 From: ReAnzu Date: Thu, 8 Mar 2018 00:47:58 -0600 Subject: [PATCH 099/159] server/posts: auto convert GIFs to WEBMs/MP4s - Default setting is false for both conversions, as this will require additional resources of the server, but is bandwidth friendly for viewers - WEBM conversion is slow, but better quality than MP4 conversion with a typically smaller file size - Tags are copied over from the original upload - Snapshots are generated for the new auto posts --- config.yaml.dist | 6 +++ server/szurubooru/api/post_api.py | 21 +++++++++-- server/szurubooru/func/images.py | 62 +++++++++++++++++++++++++++++++ server/szurubooru/func/posts.py | 45 +++++++++++++++++++++- server/szurubooru/func/util.py | 10 +++++ 5 files changed, 138 insertions(+), 6 deletions(-) diff --git a/config.yaml.dist b/config.yaml.dist index 7273ea7b..79e17230 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -27,6 +27,12 @@ thumbnails: post_height: 300 +convert: + gif: + to_webm: false + to_mp4: false + + # used to send password reset e-mails smtp: host: # example: localhost diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 27f10c16..aaa7437f 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict +from typing import Optional, Dict, List from datetime import datetime from szurubooru import db, model, errors, rest, search from szurubooru.func import ( @@ -69,13 +69,26 @@ def create_post( posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) ctx.session.add(post) ctx.session.flush() - snapshots.create(post, None if anonymous else ctx.user) - for tag in new_tags: - snapshots.create(tag, None if anonymous else ctx.user) + create_snapshots_for_post(post, new_tags, None if anonymous else ctx.user) + alternate_format_posts = posts.generate_alternate_formats(post, content) + for alternate_post, alternate_post_new_tags in alternate_format_posts: + create_snapshots_for_post( + alternate_post, + alternate_post_new_tags, + None if anonymous else ctx.user) ctx.session.commit() return _serialize_post(ctx, post) +def create_snapshots_for_post( + post: model.Post, + new_tags: List[model.Tag], + user: Optional[model.User]): + snapshots.create(post, user) + for tag in new_tags: + snapshots.create(tag, user) + + @rest.routes.get('/post/(?P[^/]+)/?') def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:view') diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index 9e62aef9..b3df55e7 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -79,6 +79,68 @@ class Image: '-', ]) + def to_webm(self) -> bytes: + with util.create_temp_file_path(suffix='.log') as phase_log_path: + # Pass 1 + self._execute([ + '-i', '{path}', + '-pass', '1', + '-passlogfile', phase_log_path, + '-vcodec', 'libvpx-vp9', + '-crf', '4', + '-b:v', '2500K', + '-acodec', 'libvorbis', + '-f', 'webm', + '-y', '/dev/null' + ]) + + # Pass 2 + return self._execute([ + '-i', '{path}', + '-pass', '2', + '-passlogfile', phase_log_path, + '-vcodec', 'libvpx-vp9', + '-crf', '4', + '-b:v', '2500K', + '-acodec', 'libvorbis', + '-f', 'webm', + '-' + ]) + + def to_mp4(self) -> bytes: + with util.create_temp_file_path(suffix='.dat') as mp4_temp_path: + width = self.width + height = self.height + altered_dimensions = False + + if self.width % 2 != 0: + width = self.width - 1 + altered_dimensions = True + + if self.height % 2 != 0: + height = self.height - 1 + altered_dimensions = True + + args = [ + '-i', '{path}', + '-vcodec', 'libx264', + '-preset', 'slow', + '-crf', '22', + '-b:v', '200K', + '-profile:v', 'main', + '-pix_fmt', 'yuv420p', + '-acodec', 'aac', + '-f', 'mp4' + ] + + if altered_dimensions: + args += ['-filter:v', 'scale=\'%d:%d\'' % (width, height)] + + self._execute(args + ['-y', mp4_temp_path]) + + with open(mp4_temp_path, 'rb') as mp4_temp: + return mp4_temp.read() + def _execute( self, cli: List[str], diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 406f6e18..219c832e 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -5,7 +5,7 @@ import sqlalchemy as sa from szurubooru import config, db, model, errors, rest from szurubooru.func import ( users, scores, comments, tags, util, - mime, images, files, image_hash, serialization) + mime, images, files, image_hash, serialization, snapshots) EMPTY_PIXEL = ( @@ -364,7 +364,7 @@ def create_post( update_post_content(post, content) new_tags = update_post_tags(post, tag_names) - return (post, new_tags) + return post, new_tags def update_post_safety(post: model.Post, safety: str) -> None: @@ -429,6 +429,47 @@ def _sync_post_content(post: model.Post) -> None: generate_post_thumbnail(post) +def generate_alternate_formats(post: model.Post, content: bytes) \ + -> List[Tuple[model.Post, List[model.Tag]]]: + assert post + assert content + new_posts = [] + if mime.is_animated_gif(content): + tag_names = [ + tag_name.name + for tag_name in [tag.names for tag in post.tags]] + + if config.config['convert']['gif']['to_mp4']: + mp4_post, new_tags = create_post( + images.Image(content).to_mp4(), + tag_names, + post.user) + update_post_flags(mp4_post, ['loop']) + update_post_safety(mp4_post, post.safety) + update_post_source(mp4_post, post.source) + new_posts += [(mp4_post, new_tags)] + + if config.config['convert']['gif']['to_webm']: + webm_post, new_tags = create_post( + images.Image(content).to_webm(), + tag_names, + post.user) + update_post_flags(webm_post, ['loop']) + update_post_safety(webm_post, post.safety) + update_post_source(webm_post, post.source) + new_posts += [(webm_post, new_tags)] + + db.session.flush() + + new_posts = [p for p in new_posts if p[0] is not None] + + new_relations = [p[0].post_id for p in new_posts] + if len(new_relations) > 0: + update_post_relations(post, new_relations) + + return new_posts + + def update_post_content(post: model.Post, content: Optional[bytes]) -> None: assert post if not content: diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 61e975b6..ba2d4dc9 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -41,6 +41,16 @@ def create_temp_file(**kwargs: Any) -> Generator: os.remove(path) +@contextmanager +def create_temp_file_path(**kwargs: Any) -> Generator: + (descriptor, path) = tempfile.mkstemp(**kwargs) + os.close(descriptor) + try: + yield path + finally: + os.remove(path) + + def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]: output_dict = {} # type: Dict[str, T] for aliases, value in source: From 7519e071e79a2cf51e41e35fce030749c71fc0cf Mon Sep 17 00:00:00 2001 From: ReAnzu Date: Sat, 24 Feb 2018 01:57:31 -0600 Subject: [PATCH 100/159] server/posts: deleting a post purges its artifacts Specifically, its thumbnail and post source. --- config.yaml.dist | 6 ++ server/szurubooru/func/posts.py | 13 ++-- .../tests/api/test_post_deleting.py | 11 +++- server/szurubooru/tests/func/test_posts.py | 61 +++++++++++++------ server/szurubooru/tests/model/test_post.py | 10 +++ server/szurubooru/tests/model/test_tag.py | 10 +++ 6 files changed, 87 insertions(+), 24 deletions(-) diff --git a/config.yaml.dist b/config.yaml.dist index 79e17230..a8b0a1ff 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -20,6 +20,12 @@ database: test_database: 'sqlite:///:memory:' # required for running the test suite +# Delete thumbnails and source files on post delete +# Original functionality is no, to mitigate the impacts of admins going +# on unchecked post purges. +delete_source_files: no + + thumbnails: avatar_width: 300 avatar_height: 300 diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 219c832e..9589a30f 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -400,6 +400,9 @@ def _before_post_delete( _mapper: Any, _connection: Any, post: model.Post) -> None: if post.post_id: image_hash.delete_image(post.post_id) + if config.config['delete_source_files']: + files.delete(get_post_content_path(post)) + files.delete(get_post_thumbnail_path(post)) def _sync_post_content(post: model.Post) -> None: @@ -727,12 +730,14 @@ def merge_posts( merge_favorites(source_post.post_id, target_post.post_id) merge_relations(source_post.post_id, target_post.post_id) - delete(source_post) - - db.session.flush() - + content = None if replace_content: content = files.get(get_post_content_path(source_post)) + + delete(source_post) + db.session.flush() + + if content is not None: update_post_content(target_post, content) diff --git a/server/szurubooru/tests/api/test_post_deleting.py b/server/szurubooru/tests/api/test_post_deleting.py index e35a4488..bb5f9ced 100644 --- a/server/szurubooru/tests/api/test_post_deleting.py +++ b/server/szurubooru/tests/api/test_post_deleting.py @@ -1,12 +1,19 @@ from unittest.mock import patch import pytest from szurubooru import api, db, model, errors -from szurubooru.func import posts, tags, snapshots +from szurubooru.func import posts, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}}) + config_injector({ + 'secret': 'secret', + 'data_dir': '', + 'delete_source_files': False, + 'privileges': { + 'posts:delete': model.User.RANK_REGULAR + } + }) def test_deleting(user_factory, post_factory, context_factory): diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index e296ec0a..d0c27ba6 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -1,6 +1,6 @@ -import os -from unittest.mock import patch from datetime import datetime +from unittest.mock import patch +import os import pytest from szurubooru import db, model from szurubooru.func import ( @@ -675,7 +675,8 @@ def test_feature_post(post_factory, user_factory): assert new_featured_post == post -def test_delete(post_factory): +def test_delete(post_factory, config_injector): + config_injector({'delete_source_files': False}) post = post_factory() db.session.add(post) db.session.flush() @@ -685,7 +686,8 @@ def test_delete(post_factory): assert posts.get_post_count() == 0 -def test_merge_posts_deletes_source_post(post_factory): +def test_merge_posts_deletes_source_post(post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() db.session.add_all([source_post, target_post]) @@ -697,7 +699,8 @@ def test_merge_posts_deletes_source_post(post_factory): assert post is not None -def test_merge_posts_with_itself(post_factory): +def test_merge_posts_with_itself(post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() db.session.add(source_post) db.session.flush() @@ -705,7 +708,8 @@ def test_merge_posts_with_itself(post_factory): posts.merge_posts(source_post, source_post, False) -def test_merge_posts_moves_tags(post_factory, tag_factory): +def test_merge_posts_moves_tags(post_factory, tag_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() tag = tag_factory() @@ -720,7 +724,9 @@ def test_merge_posts_moves_tags(post_factory, tag_factory): assert posts.get_post_by_id(target_post.post_id).tag_count == 1 -def test_merge_posts_doesnt_duplicate_tags(post_factory, tag_factory): +def test_merge_posts_doesnt_duplicate_tags( + post_factory, tag_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() tag = tag_factory() @@ -735,7 +741,9 @@ def test_merge_posts_doesnt_duplicate_tags(post_factory, tag_factory): assert posts.get_post_by_id(target_post.post_id).tag_count == 1 -def test_merge_posts_moves_comments(post_factory, comment_factory): +def test_merge_posts_moves_comments( + post_factory, comment_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() comment = comment_factory(post=source_post) @@ -749,7 +757,9 @@ def test_merge_posts_moves_comments(post_factory, comment_factory): assert posts.get_post_by_id(target_post.post_id).comment_count == 1 -def test_merge_posts_moves_scores(post_factory, post_score_factory): +def test_merge_posts_moves_scores( + post_factory, post_score_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() score = post_score_factory(post=source_post, score=1) @@ -764,7 +774,8 @@ def test_merge_posts_moves_scores(post_factory, post_score_factory): def test_merge_posts_doesnt_duplicate_scores( - post_factory, user_factory, post_score_factory): + post_factory, user_factory, post_score_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() user = user_factory() @@ -780,7 +791,9 @@ def test_merge_posts_doesnt_duplicate_scores( assert posts.get_post_by_id(target_post.post_id).score == 1 -def test_merge_posts_moves_favorites(post_factory, post_favorite_factory): +def test_merge_posts_moves_favorites( + post_factory, post_favorite_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() favorite = post_favorite_factory(post=source_post) @@ -795,7 +808,8 @@ def test_merge_posts_moves_favorites(post_factory, post_favorite_factory): def test_merge_posts_doesnt_duplicate_favorites( - post_factory, user_factory, post_favorite_factory): + post_factory, user_factory, post_favorite_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() user = user_factory() @@ -811,7 +825,8 @@ def test_merge_posts_doesnt_duplicate_favorites( assert posts.get_post_by_id(target_post.post_id).favorite_count == 1 -def test_merge_posts_moves_child_relations(post_factory): +def test_merge_posts_moves_child_relations(post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() related_post = post_factory() @@ -826,7 +841,9 @@ def test_merge_posts_moves_child_relations(post_factory): assert posts.get_post_by_id(target_post.post_id).relation_count == 1 -def test_merge_posts_doesnt_duplicate_child_relations(post_factory): +def test_merge_posts_doesnt_duplicate_child_relations( + post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() related_post = post_factory() @@ -842,7 +859,8 @@ def test_merge_posts_doesnt_duplicate_child_relations(post_factory): assert posts.get_post_by_id(target_post.post_id).relation_count == 1 -def test_merge_posts_moves_parent_relations(post_factory): +def test_merge_posts_moves_parent_relations(post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() related_post = post_factory() @@ -859,7 +877,9 @@ def test_merge_posts_moves_parent_relations(post_factory): assert posts.get_post_by_id(related_post.post_id).relation_count == 1 -def test_merge_posts_doesnt_duplicate_parent_relations(post_factory): +def test_merge_posts_doesnt_duplicate_parent_relations( + post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() related_post = post_factory() @@ -876,7 +896,9 @@ def test_merge_posts_doesnt_duplicate_parent_relations(post_factory): assert posts.get_post_by_id(related_post.post_id).relation_count == 1 -def test_merge_posts_doesnt_create_relation_loop_for_children(post_factory): +def test_merge_posts_doesnt_create_relation_loop_for_children( + post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() source_post.relations = [target_post] @@ -890,7 +912,9 @@ def test_merge_posts_doesnt_create_relation_loop_for_children(post_factory): assert posts.get_post_by_id(target_post.post_id).relation_count == 0 -def test_merge_posts_doesnt_create_relation_loop_for_parents(post_factory): +def test_merge_posts_doesnt_create_relation_loop_for_parents( + post_factory, config_injector): + config_injector({'delete_source_files': False}) source_post = post_factory() target_post = post_factory() target_post.relations = [source_post] @@ -909,6 +933,7 @@ def test_merge_posts_replaces_content( config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'data_url': 'example.com', + 'delete_source_files': False, 'thumbnails': { 'post_width': 300, 'post_height': 300, diff --git a/server/szurubooru/tests/model/test_post.py b/server/szurubooru/tests/model/test_post.py index f35e2751..ee691460 100644 --- a/server/szurubooru/tests/model/test_post.py +++ b/server/szurubooru/tests/model/test_post.py @@ -1,7 +1,17 @@ from datetime import datetime +import pytest from szurubooru import db, model +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({ + 'secret': 'secret', + 'data_dir': '', + 'delete_source_files': False + }) + + def test_saving_post(post_factory, user_factory, tag_factory): user = user_factory() tag1 = tag_factory() diff --git a/server/szurubooru/tests/model/test_tag.py b/server/szurubooru/tests/model/test_tag.py index 07bbc0e5..b677eeff 100644 --- a/server/szurubooru/tests/model/test_tag.py +++ b/server/szurubooru/tests/model/test_tag.py @@ -1,7 +1,17 @@ from datetime import datetime +import pytest from szurubooru import db, model +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({ + 'delete_source_files': False, + 'secret': 'secret', + 'data_dir': '' + }) + + def test_saving_tag(tag_factory): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) From 3f52aceca44bc52e0c8654f46c66073320109ae9 Mon Sep 17 00:00:00 2001 From: ReAnzu Date: Sat, 24 Feb 2018 23:45:00 -0600 Subject: [PATCH 101/159] server/users: harden password hashes - Changed password setup to use libsodium and argon2id (regular SHA256 hashing for passwords is inadequate as modern GPU's can hash generate billions of hashes per second). - Added code to auto migrate old passwords to the new password_hash if the existing password_hash matches either of the legacy password generation schemes (SHA1 or SHA256). - Added migration to support new password_hash format length - Added column password_revision. This field will default to 0, which all passwords will have till they're updated. After that each password hash method has a revision. --- server/requirements.txt | 1 + server/szurubooru/func/auth.py | 48 +++++++--- server/szurubooru/func/users.py | 10 ++- server/szurubooru/migrations/env.py | 8 +- ...pdate_user_table_for_hardened_passwords.py | 89 +++++++++++++++++++ server/szurubooru/model/user.py | 4 +- server/szurubooru/tests/conftest.py | 11 ++- server/szurubooru/tests/func/test_auth.py | 43 +++++++++ server/szurubooru/tests/func/test_users.py | 6 +- 9 files changed, 198 insertions(+), 22 deletions(-) create mode 100644 server/szurubooru/migrations/versions/9ef1a1643c2a_update_user_table_for_hardened_passwords.py create mode 100644 server/szurubooru/tests/func/test_auth.py diff --git a/server/requirements.txt b/server/requirements.txt index 2cd15ec1..7cc47868 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -11,3 +11,4 @@ scipy>=0.18.1 elasticsearch>=5.0.0 elasticsearch-dsl>=5.0.0 scikit-image>=0.12 +pynacl>=1.2.1 \ No newline at end of file diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index 25c991c4..c9740fe0 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -1,7 +1,10 @@ +from typing import Tuple import hashlib import random from collections import OrderedDict -from szurubooru import config, model, errors +from nacl import pwhash +from nacl.exceptions import InvalidkeyError +from szurubooru import config, model, errors, db from szurubooru.func import util @@ -16,22 +19,29 @@ RANK_MAP = OrderedDict([ ]) -def get_password_hash(salt: str, password: str) -> str: - ''' Retrieve new-style password hash. ''' +def get_password_hash(salt: str, password: str) -> Tuple[str, int]: + ''' Retrieve argon2id password hash. ''' + return pwhash.argon2id.str( + (config.config['secret'] + salt + password).encode('utf8') + ).decode('utf8'), 3 + + +def get_sha256_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]: + ''' Retrieve old-style sha256 password hash. ''' digest = hashlib.sha256() digest.update(config.config['secret'].encode('utf8')) digest.update(salt.encode('utf8')) digest.update(password.encode('utf8')) - return digest.hexdigest() + return digest.hexdigest(), 2 -def get_legacy_password_hash(salt: str, password: str) -> str: - ''' Retrieve old-style password hash. ''' +def get_sha1_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]: + ''' Retrieve old-style sha1 password hash. ''' digest = hashlib.sha1() digest.update(b'1A2/$_4xVa') digest.update(salt.encode('utf8')) digest.update(password.encode('utf8')) - return digest.hexdigest() + return digest.hexdigest(), 1 def create_password() -> str: @@ -47,11 +57,25 @@ def create_password() -> str: def is_valid_password(user: model.User, password: str) -> bool: assert user salt, valid_hash = user.password_salt, user.password_hash - possible_hashes = [ - get_password_hash(salt, password), - get_legacy_password_hash(salt, password) - ] - return valid_hash in possible_hashes + + try: + return pwhash.verify( + user.password_hash.encode('utf8'), + (config.config['secret'] + salt + password).encode('utf8')) + except InvalidkeyError: + possible_hashes = [ + get_sha256_legacy_password_hash(salt, password)[0], + get_sha1_legacy_password_hash(salt, password)[0] + ] + if valid_hash in possible_hashes: + # Convert the user password hash to the new hash + new_hash, revision = get_password_hash(salt, password) + user.password_hash = new_hash + user.password_revision = revision + db.session.commit() + return True + + return False def has_privilege(user: model.User, privilege_name: str) -> bool: diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index ba6f67f2..012debca 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -243,7 +243,10 @@ def update_user_password(user: model.User, password: str) -> None: raise InvalidPasswordError( 'Password must satisfy regex %r.' % password_regex) user.password_salt = auth.create_password() - user.password_hash = auth.get_password_hash(user.password_salt, password) + password_hash, revision = auth.get_password_hash( + user.password_salt, password) + user.password_hash = password_hash + user.password_revision = revision def update_user_email(user: model.User, email: str) -> None: @@ -308,5 +311,8 @@ def reset_user_password(user: model.User) -> str: assert user password = auth.create_password() user.password_salt = auth.create_password() - user.password_hash = auth.get_password_hash(user.password_salt, password) + password_hash, revision = auth.get_password_hash( + user.password_salt, password) + user.password_hash = password_hash + user.password_revision = revision return password diff --git a/server/szurubooru/migrations/env.py b/server/szurubooru/migrations/env.py index 7065a69e..59d031f6 100644 --- a/server/szurubooru/migrations/env.py +++ b/server/szurubooru/migrations/env.py @@ -35,7 +35,10 @@ def run_migrations_offline(): ''' url = alembic_config.get_main_option('sqlalchemy.url') alembic.context.configure( - url=url, target_metadata=target_metadata, literal_binds=True) + url=url, + target_metadata=target_metadata, + literal_binds=True, + compare_type=True) with alembic.context.begin_transaction(): alembic.context.run_migrations() @@ -56,7 +59,8 @@ def run_migrations_online(): with connectable.connect() as connection: alembic.context.configure( connection=connection, - target_metadata=target_metadata) + target_metadata=target_metadata, + compare_type=True) with alembic.context.begin_transaction(): alembic.context.run_migrations() diff --git a/server/szurubooru/migrations/versions/9ef1a1643c2a_update_user_table_for_hardened_passwords.py b/server/szurubooru/migrations/versions/9ef1a1643c2a_update_user_table_for_hardened_passwords.py new file mode 100644 index 00000000..38057728 --- /dev/null +++ b/server/szurubooru/migrations/versions/9ef1a1643c2a_update_user_table_for_hardened_passwords.py @@ -0,0 +1,89 @@ +''' +Alter the password_hash field to work with larger output. +Particularly libsodium output for greater password security. + +Revision ID: 9ef1a1643c2a +Created at: 2018-02-24 23:00:32.848575 +''' + +import sqlalchemy as sa +import sqlalchemy.ext.declarative +import sqlalchemy.orm.session +from alembic import op + + +revision = '9ef1a1643c2a' +down_revision = '02ef5f73f4ab' +branch_labels = None +depends_on = None + +Base = sa.ext.declarative.declarative_base() + + +class User(Base): + __tablename__ = 'user' + + AVATAR_GRAVATAR = 'gravatar' + + user_id = sa.Column('id', sa.Integer, primary_key=True) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_login_time = sa.Column('last_login_time', sa.DateTime) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True) + password_hash = sa.Column('password_hash', sa.Unicode(128), nullable=False) + password_salt = sa.Column('password_salt', sa.Unicode(32)) + password_revision = sa.Column( + 'password_revision', sa.SmallInteger, default=0, nullable=False) + email = sa.Column('email', sa.Unicode(64), nullable=True) + rank = sa.Column('rank', sa.Unicode(32), nullable=False) + avatar_style = sa.Column( + 'avatar_style', sa.Unicode(32), nullable=False, + default=AVATAR_GRAVATAR) + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } + + +def upgrade(): + op.alter_column( + 'user', + 'password_hash', + existing_type=sa.VARCHAR(length=64), + type_=sa.Unicode(length=128), + existing_nullable=False) + op.add_column('user', sa.Column( + 'password_revision', + sa.SmallInteger(), + nullable=True, + default=0)) + + session = sa.orm.session.Session(bind=op.get_bind()) + if session.query(User).count() >= 0: + for user in session.query(User).all(): + password_hash_length = len(user.password_hash) + if password_hash_length == 40: + user.password_revision = 1 + elif password_hash_length == 64: + user.password_revision = 2 + else: + user.password_revision = 0 + session.flush() + session.commit() + + op.alter_column( + 'user', + 'password_revision', + existing_nullable=True, + nullable=False) + + +def downgrade(): + op.alter_column( + 'user', + 'password_hash', + existing_type=sa.Unicode(length=128), + type_=sa.VARCHAR(length=64), + existing_nullable=False) + op.drop_column('user', 'password_revision') diff --git a/server/szurubooru/model/user.py b/server/szurubooru/model/user.py index dd7c0629..39c5a91b 100644 --- a/server/szurubooru/model/user.py +++ b/server/szurubooru/model/user.py @@ -23,8 +23,10 @@ class User(Base): last_login_time = sa.Column('last_login_time', sa.DateTime) version = sa.Column('version', sa.Integer, default=1, nullable=False) name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True) - password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False) + password_hash = sa.Column('password_hash', sa.Unicode(128), nullable=False) password_salt = sa.Column('password_salt', sa.Unicode(32)) + password_revision = sa.Column( + 'password_revision', sa.SmallInteger, default=0, nullable=False) email = sa.Column('email', sa.Unicode(64), nullable=True) rank = sa.Column('rank', sa.Unicode(32), nullable=False) avatar_style = sa.Column( diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index e71f9609..db7806e8 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -115,11 +115,16 @@ def config_injector(): @pytest.fixture def user_factory(): - def factory(name=None, rank=model.User.RANK_REGULAR, email='dummy'): + def factory( + name=None, + rank=model.User.RANK_REGULAR, + email='dummy', + password_salt=None, + password_hash=None): user = model.User() user.name = name or get_unique_name() - user.password_salt = 'dummy' - user.password_hash = 'dummy' + user.password_salt = password_salt or 'dummy' + user.password_hash = password_hash or 'dummy' user.email = email user.rank = rank user.creation_time = datetime(1997, 1, 1) diff --git a/server/szurubooru/tests/func/test_auth.py b/server/szurubooru/tests/func/test_auth.py new file mode 100644 index 00000000..5d0955ad --- /dev/null +++ b/server/szurubooru/tests/func/test_auth.py @@ -0,0 +1,43 @@ +from szurubooru.func import auth +import pytest + + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'secret': 'testSecret'}) + + +def test_get_password_hash(): + salt, password = ('testSalt', 'pass') + result, revision = auth.get_password_hash(salt, password) + assert result + assert revision == 3 + hash_parts = list( + filter(lambda e: e is not None and e != '', result.split('$'))) + assert len(hash_parts) == 5 + assert hash_parts[0] == 'argon2id' + + +def test_get_sha256_legacy_password_hash(): + salt, password = ('testSalt', 'pass') + result, revision = auth.get_sha256_legacy_password_hash(salt, password) + hash = '2031ac9631353ac9303719a7f808a24f79aa1d71712c98523e4bb4cce579428a' + assert result == hash + assert revision == 2 + + +def test_get_sha1_legacy_password_hash(): + salt, password = ('testSalt', 'pass') + result, revision = auth.get_sha1_legacy_password_hash(salt, password) + assert result == '1eb1f953d9be303a1b54627e903e6124cfb1245b' + assert revision == 1 + + +def test_is_valid_password_auto_upgrades_user_password_hash(user_factory): + salt, password = ('testSalt', 'pass') + hash, revision = auth.get_sha256_legacy_password_hash(salt, password) + user = user_factory(password_salt=salt, password_hash=hash) + result = auth.is_valid_password(user, password) + assert result is True + assert user.password_hash != hash + assert user.password_revision > revision diff --git a/server/szurubooru/tests/func/test_users.py b/server/szurubooru/tests/func/test_users.py index 76c52c8b..55061276 100644 --- a/server/szurubooru/tests/func/test_users.py +++ b/server/szurubooru/tests/func/test_users.py @@ -320,10 +320,11 @@ def test_update_user_password(user_factory, config_injector): with patch('szurubooru.func.auth.create_password'), \ patch('szurubooru.func.auth.get_password_hash'): auth.create_password.return_value = 'salt' - auth.get_password_hash.return_value = 'hash' + auth.get_password_hash.return_value = ('hash', 3) users.update_user_password(user, 'a') assert user.password_salt == 'salt' assert user.password_hash == 'hash' + assert user.password_revision == 3 def test_update_user_email_with_too_long_string(user_factory): @@ -447,7 +448,8 @@ def test_reset_user_password(user_factory): patch('szurubooru.func.auth.get_password_hash'): user = user_factory() auth.create_password.return_value = 'salt' - auth.get_password_hash.return_value = 'hash' + auth.get_password_hash.return_value = ('hash', 3) users.reset_user_password(user) assert user.password_salt == 'salt' assert user.password_hash == 'hash' + assert user.password_revision == 3 From c770ad8f28064574fcbcdeea44e895b911acd8f3 Mon Sep 17 00:00:00 2001 From: ReAnzu Date: Fri, 9 Mar 2018 00:19:17 -0600 Subject: [PATCH 102/159] client/posts: fix copy tags list of string values error #153 --- client/js/models/post.js | 1 + client/js/views/post_upload_view.js | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/client/js/models/post.js b/client/js/models/post.js index 2b63bd62..c5750663 100644 --- a/client/js/models/post.js +++ b/client/js/models/post.js @@ -39,6 +39,7 @@ class Post extends events.EventTarget { get flags() { return this._flags; } get tags() { return this._tags; } + get tagNames() { return this._tags.map(tag => tag.names[0]); } get notes() { return this._notes; } get comments() { return this._comments; } get relations() { return this._relations; } diff --git a/client/js/views/post_upload_view.js b/client/js/views/post_upload_view.js index 32b10bd1..b64fe260 100644 --- a/client/js/views/post_upload_view.js +++ b/client/js/views/post_upload_view.js @@ -297,7 +297,7 @@ class PostUploadView extends events.EventTarget { let lookalikeNode = rowNode.querySelector( `.lookalikes li:nth-child(${i + 1})`); if (lookalikeNode.querySelector('[name=copy-tags]').checked) { - uploadable.tags = uploadable.tags.concat(lookalike.post.tags); + uploadable.tags = uploadable.tags.concat(lookalike.post.tagNames); } if (lookalikeNode.querySelector('[name=add-relation]').checked) { uploadable.relations.push(lookalike.post.id); From 26a1451ff62f10cb9f854fe791490668bb88dc28 Mon Sep 17 00:00:00 2001 From: Michael Serajnik Date: Sat, 16 Dec 2017 03:45:51 +0100 Subject: [PATCH 103/159] client/css: improve mobile styling --- client/css/comment-control.styl | 1 - client/css/comment-list-control.styl | 5 +++ client/css/comment-list-view.styl | 18 +++++++--- client/css/core-forms.styl | 16 +++++++-- client/css/core-general.styl | 36 +++++++++++++++++++ client/css/home-view.styl | 3 ++ client/css/pager.styl | 2 +- client/css/post-list-view.styl | 25 +++++++++++++ client/css/post-main-view.styl | 3 ++ client/css/snapshots-list-view.styl | 8 +++-- client/css/tag-categories-view.styl | 6 ++++ client/css/tag-list-view.styl | 4 +++ client/css/user-list-view.styl | 3 ++ client/html/tag_categories.tpl | 24 +++++++------ client/html/tags_page.tpl | 2 +- client/html/top_navigation.tpl | 4 +++ .../controllers/top_navigation_controller.js | 2 ++ client/js/views/top_navigation_view.js | 32 +++++++++++++++++ 18 files changed, 170 insertions(+), 24 deletions(-) diff --git a/client/css/comment-control.styl b/client/css/comment-control.styl index 6b730172..a54dacc0 100644 --- a/client/css/comment-control.styl +++ b/client/css/comment-control.styl @@ -3,7 +3,6 @@ $comment-header-background-color = $top-navigation-color $comment-border-color = #DDD .comment-container - margin: 0 0 1em 0 padding: 0 0 0 60px .avatar diff --git a/client/css/comment-list-control.styl b/client/css/comment-list-control.styl index 63459d64..fe847245 100644 --- a/client/css/comment-list-control.styl +++ b/client/css/comment-list-control.styl @@ -2,3 +2,8 @@ list-style-type: none margin: 0 padding: 0 + + >li + margin-bottom: 1em + &:last-child + margin-bottom: 0 diff --git a/client/css/comment-list-view.styl b/client/css/comment-list-view.styl index 697b372b..bd50beb8 100644 --- a/client/css/comment-list-view.styl +++ b/client/css/comment-list-view.styl @@ -1,15 +1,24 @@ +@import colors +$comment-border-color = $top-navigation-color + .global-comment-list text-align: left &>ul list-style-type: none - margin: 1em 0 + margin: 1em 0 0 padding: 0 + &>li + margin-top: 2em + padding-top: 2em + border-top: 3px solid $comment-border-color + &:first-child + margin-top: 0 + padding-top: 0 + border-top: none + @media (max-width: 700px) - &>li - margin-bottom: 5em - padding: 1vw .post-thumbnail margin-bottom: 1em .thumbnail @@ -19,7 +28,6 @@ @media (min-width: 700px) &>li padding-left: 13em - margin-bottom: 2em .post-thumbnail float: left margin: 0 0 1em -13em diff --git a/client/css/core-forms.styl b/client/css/core-forms.styl index a163718b..f083ffbe 100644 --- a/client/css/core-forms.styl +++ b/client/css/core-forms.styl @@ -31,13 +31,22 @@ form.horizontal margin-bottom: 1em .input, .buttons, ul display: inline-block - vertical-align: middle + vertical-align: top margin: 0 padding: 0 input - vertical-align: middle + vertical-align: top .buttons margin-right: 0.5em + @media (max-width: 1000px) + display: block + .input, .buttons, ul + display: block + margin-top: 0.5em + &:first-child + margin-top: 0 + .buttons + margin-right: 0 @@ -213,10 +222,13 @@ input[type=submit] cursor: pointer font-size: 100% padding: 0.2em 0.7em + border-radius: 0 border: 2px solid $button-enabled-background-color background: $button-enabled-background-color color: $button-enabled-text-color outline: 0 /* something on Chrome */ + -moz-appearance: none + -webkit-appearance: none &:disabled cursor: default diff --git a/client/css/core-general.styl b/client/css/core-general.styl index 9209bc30..3892a825 100644 --- a/client/css/core-general.styl +++ b/client/css/core-general.styl @@ -125,6 +125,37 @@ nav li display: inline-block float: left + #mobile-navigation-toggle + display: none + width: 100% + padding: 0 1.5vw + line-height: 2.3em + font-family: inherit + border: none + background: none + color: $active-tab-text-color + .site-name + display: block + float: left + max-width: 50vw + overflow: hidden + text-overflow: ellipsis + .toggle-icon + display: block + float: right + @media (max-width: 1000px) + text-align: left + li + display: none + float: none + a + display: block + padding: 0 1.5vw + #mobile-navigation-toggle + display: block + &.opened + li + display: block ul li[data-name=account], ul li[data-name=register], ul li[data-name=login], @@ -141,6 +172,8 @@ nav margin-right: 0.6em margin-left: calc(0.6em - 1.2em) float: left + @media (max-width: 1000px) + display: none a .access-key text-decoration: underline @@ -194,6 +227,9 @@ a .access-key margin-top: 0 !important margin-bottom: 0 !important +.table-wrap + overflow-x: scroll + /* hack to prevent text from being copied */ [data-pseudo-content]:before { content: attr(data-pseudo-content) diff --git a/client/css/home-view.styl b/client/css/home-view.styl index 06292aea..00e46d70 100644 --- a/client/css/home-view.styl +++ b/client/css/home-view.styl @@ -15,6 +15,7 @@ margin: 0 auto 2em auto form + display: inline-block width: auto vertical-align: middle margin: 0 0 2em 0 @@ -52,6 +53,8 @@ li display: inline white-space: nowrap + @media (max-width: 800px) + display: block .sep word-spacing: 1.1em background-repeat: no-repeat diff --git a/client/css/pager.styl b/client/css/pager.styl index df35f3e8..3976404c 100644 --- a/client/css/pager.styl +++ b/client/css/pager.styl @@ -8,7 +8,7 @@ .page position: relative .page-header - margin: 0.5em 0.5em 0.5em 0 + margin: 0.5em 0 position: relative &:before display: block diff --git a/client/css/post-list-view.styl b/client/css/post-list-view.styl index 10ebb6df..41d4cc13 100644 --- a/client/css/post-list-view.styl +++ b/client/css/post-list-view.styl @@ -145,12 +145,19 @@ margin-bottom: 0.75em * vertical-align: top + @media (max-width: 1000px) + display: block input margin-bottom: 0.25em margin-right: 0.25em input[name=search-text] width: 25em + @media (max-width: 1000px) + display: block + width: 100% + margin-bottom: 0.5em .append + vertical-align: middle font-size: 0.95em color: $inactive-link-color .bulk-edit @@ -163,6 +170,11 @@ &.hidden display: none .bulk-edit-tags + &.opened + .hint + @media (max-width: 1000px) + display: block + margin-bottom: 0.5em &:not(.opened) [type=text], .start @@ -171,8 +183,21 @@ display: none input[name=tag] width: 12em + @media (max-width: 1000px) + display: block + width: 100% + margin-bottom: 0.5em + .append + &.open, + &.hint + @media (max-width: 1000px) + margin-left: 0 .hint margin-right: 1em + .bulk-edit-safety + .append + @media (max-width: 1000px) + margin-left: 0 .safety margin-right: 0.25em diff --git a/client/css/post-main-view.styl b/client/css/post-main-view.styl index cfebf591..3c08745b 100644 --- a/client/css/post-main-view.styl +++ b/client/css/post-main-view.styl @@ -138,6 +138,9 @@ margin: 0 padding: 0 + form + width: auto + label:not(.file-dropper) margin-bottom: 0.3em display: block diff --git a/client/css/snapshots-list-view.styl b/client/css/snapshots-list-view.styl index a858a07d..b059a64a 100644 --- a/client/css/snapshots-list-view.styl +++ b/client/css/snapshots-list-view.styl @@ -8,11 +8,16 @@ $snapshot-merged-background-color = #FEC ul margin: 0 auto + padding: 0 width: 100% max-width: 35em list-style-type: none li + margin-bottom: 1em + &:last-child + margin-bottom: 0 + .time float: right @@ -39,6 +44,3 @@ $snapshot-merged-background-color = #FEC background: $snapshot-merged-background-color &+.details background: lighten($snapshot-merged-background-color, 50%) - - div.details - margin-bottom: 2em diff --git a/client/css/tag-categories-view.styl b/client/css/tag-categories-view.styl index d5f07627..b8a91802 100644 --- a/client/css/tag-categories-view.styl +++ b/client/css/tag-categories-view.styl @@ -17,6 +17,12 @@ text-align: center &.remove, &.set-default white-space: pre + th + white-space: nowrap + &:first-child + padding-left: 0 + &:last-child + padding-right: 0 tfoot display: none form diff --git a/client/css/tag-list-view.styl b/client/css/tag-list-view.styl index ac4b73de..4fd167f5 100644 --- a/client/css/tag-list-view.styl +++ b/client/css/tag-list-view.styl @@ -11,6 +11,7 @@ th, td padding: 0.1em 0.5em th + white-space: nowrap background: $top-navigation-color .names width: 28% @@ -47,6 +48,9 @@ width: auto input[name=search-text] width: 25em + @media (max-width: 1000px) + width: 100% .append + vertical-align: middle font-size: 0.95em color: $inactive-link-color diff --git a/client/css/user-list-view.styl b/client/css/user-list-view.styl index f575a97c..1fba50b5 100644 --- a/client/css/user-list-view.styl +++ b/client/css/user-list-view.styl @@ -34,6 +34,9 @@ width: auto input[name=search-text] width: 25em + @media (max-width: 1000px) + width: 100% .append + vertical-align: middle font-size: 0.95em color: $inactive-link-color diff --git a/client/html/tag_categories.tpl b/client/html/tag_categories.tpl index f401d515..fe6b8987 100644 --- a/client/html/tag_categories.tpl +++ b/client/html/tag_categories.tpl @@ -1,17 +1,19 @@

    Tag categories

    - - - - - - - - - - -
    Category nameCSS colorUsages
    +
    + + + + + + + + + + +
    Category nameCSS colorUsages
    +
    <% if (ctx.canCreate) { %>

    Add new category

    diff --git a/client/html/tags_page.tpl b/client/html/tags_page.tpl index be3b143a..8c973984 100644 --- a/client/html/tags_page.tpl +++ b/client/html/tags_page.tpl @@ -1,4 +1,4 @@ -
    +
    <% if (ctx.response.results.length) { %> diff --git a/client/html/top_navigation.tpl b/client/html/top_navigation.tpl index b0689820..f1af4610 100644 --- a/client/html/top_navigation.tpl +++ b/client/html/top_navigation.tpl @@ -1,5 +1,9 @@ diff --git a/client/html/user_tokens.tpl b/client/html/user_tokens.tpl new file mode 100644 index 00000000..73db7a17 --- /dev/null +++ b/client/html/user_tokens.tpl @@ -0,0 +1,74 @@ +
    +
    + <% if (ctx.tokens.length > 0) { %> +
    + <% _.each(ctx.tokens, function(token, index) { %> +
    +
    +
    Token:
    +
    Note:
    +
    Created:
    +
    Expires:
    +
    Last used:
    +
    +
    +
    <%= token.token %>
    +
    + <% if (token.note !== null) { %> + <%= token.note %> + <% } else { %> + No note + <% } %> + (change) +
    +
    <%= ctx.makeRelativeTime(token.creationTime) %>
    +
    + <% if (token.expirationTime) { %> + <%= ctx.makeRelativeTime(token.expirationTime) %> + <% } else { %> + No expiration + <% } %> +
    +
    <%= ctx.makeRelativeTime(token.lastUsageTime) %>
    +
    +
    +
    +
    +
    + + <% if (token.isCurrentAuthToken) { %> + + <% } else { %> + + <% } %> + +
    +
    +
    +
    + <% }); %> +
    + <% } else { %> +

    No Registered Tokens

    + <% } %> +
    +
      +
    • + <%= ctx.makeTextInput({ + text: 'Note', + id: 'note', + }) %> +
    • +
    • + <%= ctx.makeDateInput({ + text: 'Expires', + id: 'expirationTime', + }) %> +
    • +
    +
    + +
    + +
    diff --git a/client/js/api.js b/client/js/api.js index abfeb6f0..3623045b 100644 --- a/client/js/api.js +++ b/client/js/api.js @@ -15,6 +15,7 @@ class Api extends events.EventTarget { this.user = null; this.userName = null; this.userPassword = null; + this.token = null; this.cache = {}; this.allRanks = [ 'anonymous', @@ -87,11 +88,76 @@ class Api extends events.EventTarget { loginFromCookies() { const auth = cookies.getJSON('auth'); - return auth && auth.user && auth.password ? - this.login(auth.user, auth.password, true) : + return auth && auth.user && auth.token ? + this.loginWithToken(auth.user, auth.token, true) : Promise.resolve(); } + loginWithToken(userName, token, doRemember) { + this.cache = {}; + return new Promise((resolve, reject) => { + this.userName = userName; + this.token = token; + this.get('/user/' + userName + '?bump-login=true') + .then(response => { + const options = {}; + if (doRemember) { + options.expires = 365; + } + cookies.set( + 'auth', + {'user': userName, 'token': token}, + options); + this.user = response; + resolve(); + this.dispatchEvent(new CustomEvent('login')); + }, error => { + reject(error); + this.logout(); + }); + }); + } + + createToken(userName, options) { + let userTokenRequest = { + enabled: true, + note: 'Web Login Token' + }; + if (typeof options.expires !== 'undefined') { + userTokenRequest.expirationTime = new Date().addDays(options.expires).toISOString() + } + return new Promise((resolve, reject) => { + this.post('/user-token/' + userName, userTokenRequest) + .then(response => { + cookies.set( + 'auth', + {'user': userName, 'token': response.token}, + options); + this.userName = userName; + this.token = response.token; + this.userPassword = null; + }, error => { + reject(error); + }); + }); + } + + deleteToken(userName, userToken) { + return new Promise((resolve, reject) => { + this.delete('/user-token/' + userName + '/' + userToken, {}) + .then(response => { + const options = {}; + cookies.set( + 'auth', + {'user': userName, 'token': null}, + options); + resolve(); + }, error => { + reject(error); + }); + }); + } + login(userName, userPassword, doRemember) { this.cache = {}; return new Promise((resolve, reject) => { @@ -103,10 +169,7 @@ class Api extends events.EventTarget { if (doRemember) { options.expires = 365; } - cookies.set( - 'auth', - {'user': userName, 'password': userPassword}, - options); + this.createToken(this.userName, options); this.user = response; resolve(); this.dispatchEvent(new CustomEvent('login')); @@ -118,9 +181,20 @@ class Api extends events.EventTarget { } logout() { + let self = this; + this.deleteToken(this.userName, this.token) + .then(response => { + self._logout(); + }, error => { + self._logout(); + }); + } + + _logout() { this.user = null; this.userName = null; this.userPassword = null; + this.token = null; this.dispatchEvent(new CustomEvent('logout')); } @@ -137,6 +211,10 @@ class Api extends events.EventTarget { } } + isCurrentAuthToken(userToken) { + return userToken.token === this.token; + } + _getFullUrl(url) { const fullUrl = (config.apiUrl + '/' + url).replace(/([^:])\/+/g, '$1/'); @@ -258,7 +336,11 @@ class Api extends events.EventTarget { } try { - if (this.userName && this.userPassword) { + if (this.userName && this.token) { + req.auth = null; + req.set('Authorization', 'Token ' + + new Buffer(this.userName + ":" + this.token).toString('base64')) + } else if (this.userName && this.userPassword) { req.auth( this.userName, encodeURIComponent(this.userPassword) diff --git a/client/js/controllers/user_controller.js b/client/js/controllers/user_controller.js index 46020f37..d042e41f 100644 --- a/client/js/controllers/user_controller.js +++ b/client/js/controllers/user_controller.js @@ -7,6 +7,7 @@ const misc = require('../util/misc.js'); const config = require('../config.js'); const views = require('../util/views.js'); const User = require('../models/user.js'); +const UserToken = require('../models/user_token.js'); const topNavigation = require('../models/top_navigation.js'); const UserView = require('../views/user_view.js'); const EmptyView = require('../views/empty_view.js'); @@ -21,8 +22,28 @@ class UserController { return; } + this._successMessages = []; + this._errorMessages = []; + + let userTokenPromise = Promise.resolve([]); + if (section === 'list-tokens') { + userTokenPromise = UserToken.get(userName) + .then(userTokens => { + return userTokens.map(token => { + token.isCurrentAuthToken = api.isCurrentAuthToken(token); + return token; + }); + }, error => { + return []; + }); + } + topNavigation.setTitle('User ' + userName); - User.get(userName).then(user => { + Promise.all([ + userTokenPromise, + User.get(userName) + ]).then(responses => { + const [userTokens, user] = responses; const isLoggedIn = api.isLoggedIn(user); const infix = isLoggedIn ? 'self' : 'any'; @@ -48,6 +69,7 @@ class UserController { } else { topNavigation.activate('users'); } + this._view = new UserView({ user: user, section: section, @@ -58,18 +80,51 @@ class UserController { canEditRank: api.hasPrivilege(`users:edit:${infix}:rank`), canEditAvatar: api.hasPrivilege(`users:edit:${infix}:avatar`), canEditAnything: api.hasPrivilege(`users:edit:${infix}`), + canListTokens: api.hasPrivilege(`userTokens:list:${infix}`), + canCreateToken: api.hasPrivilege(`userTokens:create:${infix}`), + canEditToken: api.hasPrivilege(`userTokens:edit:${infix}`), + canDeleteToken: api.hasPrivilege(`userTokens:delete:${infix}`), canDelete: api.hasPrivilege(`users:delete:${infix}`), ranks: ranks, + tokens: userTokens, }); this._view.addEventListener('change', e => this._evtChange(e)); this._view.addEventListener('submit', e => this._evtUpdate(e)); this._view.addEventListener('delete', e => this._evtDelete(e)); + this._view.addEventListener('create-token', e => this._evtCreateToken(e)); + this._view.addEventListener('delete-token', e => this._evtDeleteToken(e)); + this._view.addEventListener('update-token', e => this._evtUpdateToken(e)); + + for (let message of this._successMessages) { + this.showSuccess(message); + } + + for (let message of this._errorMessages) { + this.showError(message); + } + }, error => { this._view = new EmptyView(); this._view.showError(error.message); }); } + showSuccess(message) { + if (typeof this._view === 'undefined') { + this._successMessages.push(message) + } else { + this._view.showSuccess(message); + } + } + + showError(message) { + if (typeof this._view === 'undefined') { + this._errorMessages.push(message) + } else { + this._view.showError(message); + } + } + _evtChange(e) { misc.enableExitConfirmation(); } @@ -148,6 +203,53 @@ class UserController { this._view.enableForm(); }); } + + _evtCreateToken(e) { + this._view.clearMessages(); + this._view.disableForm(); + UserToken.create(e.detail.user.name, e.detail.note, e.detail.expirationTime) + .then(response => { + const ctx = router.show(uri.formatClientLink('user', e.detail.user.name, 'list-tokens')); + ctx.controller.showSuccess('Token ' + response.token + ' created.'); + }, error => { + this._view.showError(error.message); + this._view.enableForm(); + }); + } + + _evtDeleteToken(e) { + this._view.clearMessages(); + this._view.disableForm(); + if (api.isCurrentAuthToken(e.detail.userToken)) { + router.show(uri.formatClientLink('logout')); + } else { + e.detail.userToken.delete(e.detail.user.name) + .then(() => { + const ctx = router.show(uri.formatClientLink('user', e.detail.user.name, 'list-tokens')); + ctx.controller.showSuccess('Token ' + e.detail.userToken.token + ' deleted.'); + }, error => { + this._view.showError(error.message); + this._view.enableForm(); + }); + } + } + + _evtUpdateToken(e) { + this._view.clearMessages(); + this._view.disableForm(); + + if (e.detail.note !== undefined) { + e.detail.userToken.note = e.detail.note; + } + + e.detail.userToken.save(e.detail.user.name).then(response => { + const ctx = router.show(uri.formatClientLink('user', e.detail.user.name, 'list-tokens')); + ctx.controller.showSuccess('Token ' + response.token + ' updated.'); + }, error => { + this._view.showError(error.message); + this._view.enableForm(); + }); + } } module.exports = router => { @@ -157,6 +259,9 @@ module.exports = router => { router.enter(['user', ':name', 'edit'], (ctx, next) => { ctx.controller = new UserController(ctx, 'edit'); }); + router.enter(['user', ':name', 'list-tokens'], (ctx, next) => { + ctx.controller = new UserController(ctx, 'list-tokens'); + }); router.enter(['user', ':name', 'delete'], (ctx, next) => { ctx.controller = new UserController(ctx, 'delete'); }); diff --git a/client/js/models/user_token.js b/client/js/models/user_token.js new file mode 100644 index 00000000..6e70a94b --- /dev/null +++ b/client/js/models/user_token.js @@ -0,0 +1,116 @@ +'use strict'; + +const api = require('../api.js'); +const uri = require('../util/uri.js'); +const events = require('../events.js'); + +class UserToken extends events.EventTarget { + constructor() { + super(); + this._orig = {}; + this._updateFromResponse({}); + } + + get token() { return this._token; } + get note() { return this._note; } + get enabled() { return this._enabled; } + get version() { return this._version; } + get expirationTime() { return this._expirationTime; } + get creationTime() { return this._creationTime; } + get lastEditTime() { return this._lastEditTime; } + get lastUsageTime() { return this._lastUsageTime; } + + set note(value) { this._note = value; } + + static fromResponse(response) { + if (typeof response.results !== 'undefined') { + let tokenList = []; + for (let responseToken of response.results) { + const token = new UserToken(); + token._updateFromResponse(responseToken); + tokenList.push(token) + } + return tokenList; + } else { + const ret = new UserToken(); + ret._updateFromResponse(response); + return ret; + } + } + + static get(userName) { + return api.get(uri.formatApiLink('user-tokens', userName)) + .then(response => { + return Promise.resolve(UserToken.fromResponse(response)); + }); + } + + static create(userName, note, expirationTime) { + let userTokenRequest = { + enabled: true + }; + if (note) { + userTokenRequest.note = note; + } + if (expirationTime) { + userTokenRequest.expirationTime = expirationTime; + } + return api.post(uri.formatApiLink('user-token', userName), userTokenRequest) + .then(response => { + return Promise.resolve(UserToken.fromResponse(response)) + }); + } + + save(userName) { + const detail = {version: this._version}; + + if (this._note !== this._orig._note) { + detail.note = this._note; + } + + return api.put( + uri.formatApiLink('user-token', userName, this._orig._token), + detail) + .then(response => { + this._updateFromResponse(response); + this.dispatchEvent(new CustomEvent('change', { + detail: { + userToken: this, + }, + })); + return Promise.resolve(this); + }); + } + + delete(userName) { + return api.delete( + uri.formatApiLink('user-token', userName, this._orig._token), + {version: this._version}) + .then(response => { + this.dispatchEvent(new CustomEvent('delete', { + detail: { + userToken: this, + }, + })); + return Promise.resolve(); + }); + } + + _updateFromResponse(response) { + const map = { + _token: response.token, + _note: response.note, + _enabled: response.enabled, + _expirationTime: response.expirationTime, + _version: response.version, + _creationTime: response.creationTime, + _lastEditTime: response.lastEditTime, + _lastUsageTime: response.lastUsageTime, + }; + + Object.assign(this, map); + Object.assign(this._orig, map); + } +} + +module.exports = UserToken; diff --git a/client/js/util/polyfill.js b/client/js/util/polyfill.js index 91186b2a..71ee9724 100644 --- a/client/js/util/polyfill.js +++ b/client/js/util/polyfill.js @@ -59,3 +59,10 @@ Number.prototype.between = function(a, b, inclusive) { // non standard Promise.prototype.abort = () => {}; + +// non standard +Date.prototype.addDays = function(days) { + let dat = new Date(this.valueOf()); + dat.setDate(dat.getDate() + days); + return dat; +}; diff --git a/client/js/util/views.js b/client/js/util/views.js index b0b7ccec..9f238b1e 100644 --- a/client/js/util/views.js +++ b/client/js/util/views.js @@ -168,6 +168,11 @@ function makeNumericInput(options) { return makeInput(options); } +function makeDateInput(options) { + options.type = 'date'; + return makeInput(options) +} + function getPostUrl(id, parameters) { return uri.formatClientLink( 'post', id, @@ -392,6 +397,7 @@ function getTemplate(templatePath) { makePasswordInput: makePasswordInput, makeEmailInput: makeEmailInput, makeColorInput: makeColorInput, + makeDateInput: makeDateInput, makePostLink: makePostLink, makeTagLink: makeTagLink, makeUserLink: makeUserLink, diff --git a/client/js/views/user_tokens_view.js b/client/js/views/user_tokens_view.js new file mode 100644 index 00000000..f6c84800 --- /dev/null +++ b/client/js/views/user_tokens_view.js @@ -0,0 +1,134 @@ +'use strict'; + +const events = require('../events.js'); +const views = require('../util/views.js'); + +const template = views.getTemplate('user-tokens'); + +class UserTokenView extends events.EventTarget { + constructor(ctx) { + super(); + + this._user = ctx.user; + this._tokens = ctx.tokens; + this._hostNode = ctx.hostNode; + this._tokenFormNodes = []; + views.replaceContent(this._hostNode, template(ctx)); + views.decorateValidator(this._formNode); + + this._formNode.addEventListener('submit', e => this._evtSubmit(e)); + + this._decorateTokenForms(); + this._decorateTokenNoteChangeLinks(); + } + + _decorateTokenForms() { + this._tokenFormNodes = []; + for (let i = 0; i < this._tokens.length; i++) { + let formNode = this._hostNode.querySelector( + '.token[data-token-id=\"' + i + '\"]'); + formNode.addEventListener('submit', e => this._evtDelete(e)); + this._tokenFormNodes.push(formNode); + } + } + + _decorateTokenNoteChangeLinks() { + for (let i = 0; i < this._tokens.length; i++) { + let linkNode = this._hostNode.querySelector( + '.token-change-note[data-token-id=\"' + i + '\"]'); + linkNode.addEventListener( + 'click', e => this._evtChangeNoteClick(e)); + } + } + + clearMessages() { + views.clearMessages(this._hostNode); + } + + showSuccess(message) { + views.showSuccess(this._hostNode, message); + } + + showError(message) { + views.showError(this._hostNode, message); + } + + enableForm() { + views.enableForm(this._formNode); + for (let formNode of this._tokenFormNodes) { + views.enableForm(formNode); + } + } + + disableForm() { + views.disableForm(this._formNode); + for (let formNode of this._tokenFormNodes) { + views.disableForm(formNode); + } + } + + _evtDelete(e) { + e.preventDefault(); + const userToken = this._tokens[parseInt( + e.target.getAttribute('data-token-id'))]; + this.dispatchEvent(new CustomEvent('delete', { + detail: { + user: this._user, + userToken: userToken, + }, + })); + } + + _evtSubmit(e) { + e.preventDefault(); + this.dispatchEvent(new CustomEvent('submit', { + detail: { + user: this._user, + + note: this._userTokenNoteInputNode ? + this._userTokenNoteInputNode.value : + undefined, + + expirationTime: + (this._userTokenExpirationTimeInputNode + && this._userTokenExpirationTimeInputNode.value) ? + new Date(this._userTokenExpirationTimeInputNode.value) + .toISOString() : + undefined, + }, + })); + } + + _evtChangeNoteClick(e) { + e.preventDefault(); + const userToken = this._tokens[ + parseInt(e.target.getAttribute('data-token-id'))]; + const text = window.prompt( + 'Please enter the new name:', + userToken.note !== null ? userToken.note : undefined); + if (!text) { + return; + } + this.dispatchEvent(new CustomEvent('update', { + detail: { + user: this._user, + userToken: userToken, + note: text ? text : undefined, + }, + })); + } + + get _formNode() { + return this._hostNode.querySelector('#create-token-form'); + } + + get _userTokenNoteInputNode() { + return this._formNode.querySelector('.note input'); + } + + get _userTokenExpirationTimeInputNode() { + return this._formNode.querySelector('.expirationTime input'); + } +} + +module.exports = UserTokenView; diff --git a/client/js/views/user_view.js b/client/js/views/user_view.js index d52e8999..75fd154d 100644 --- a/client/js/views/user_view.js +++ b/client/js/views/user_view.js @@ -3,6 +3,7 @@ const events = require('../events.js'); const views = require('../util/views.js'); const UserDeleteView = require('./user_delete_view.js'); +const UserTokensView = require('./user_tokens_view.js'); const UserSummaryView = require('./user_summary_view.js'); const UserEditView = require('./user_edit_view.js'); const EmptyView = require('../views/empty_view.js'); @@ -45,7 +46,17 @@ class UserView extends events.EventTarget { this._view = new UserEditView(ctx); events.proxyEvent(this._view, this, 'submit'); } - + } else if (ctx.section == 'list-tokens') { + if (!this._ctx.canListTokens) { + this._view = new EmptyView(); + this._view.showError( + 'You don\'t have privileges to view user tokens.'); + } else { + this._view = new UserTokensView(ctx); + events.proxyEvent(this._view, this, 'delete', 'delete-token'); + events.proxyEvent(this._view, this, 'submit', 'create-token'); + events.proxyEvent(this._view, this, 'update', 'update-token'); + } } else if (ctx.section == 'delete') { if (!this._ctx.canDelete) { this._view = new EmptyView(); diff --git a/config.yaml.dist b/config.yaml.dist index a8b0a1ff..297dc6fb 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -91,6 +91,15 @@ privileges: 'users:delete:any': administrator 'users:delete:self': regular + 'user_tokens:list:any': administrator + 'user_tokens:list:self': regular + 'user_tokens:create:any': administrator + 'user_tokens:create:self': regular + 'user_tokens:edit:any': administrator + 'user_tokens:edit:self': regular + 'user_tokens:delete:any': administrator + 'user_tokens:delete:self': regular + 'posts:create:anonymous': regular 'posts:create:identified': regular 'posts:list': anonymous diff --git a/server/requirements.txt b/server/requirements.txt index 7cc47868..b11c3b5d 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -11,4 +11,6 @@ scipy>=0.18.1 elasticsearch>=5.0.0 elasticsearch-dsl>=5.0.0 scikit-image>=0.12 -pynacl>=1.2.1 \ No newline at end of file +pynacl>=1.2.1 +pytz>=2018.3 +pyRFC3339>=1.0 \ No newline at end of file diff --git a/server/szurubooru/api/__init__.py b/server/szurubooru/api/__init__.py index 2a2d5af7..0d7f75f8 100644 --- a/server/szurubooru/api/__init__.py +++ b/server/szurubooru/api/__init__.py @@ -1,5 +1,6 @@ import szurubooru.api.info_api import szurubooru.api.user_api +import szurubooru.api.user_token_api import szurubooru.api.post_api import szurubooru.api.tag_api import szurubooru.api.tag_category_api diff --git a/server/szurubooru/api/user_token_api.py b/server/szurubooru/api/user_token_api.py new file mode 100644 index 00000000..77398239 --- /dev/null +++ b/server/szurubooru/api/user_token_api.py @@ -0,0 +1,83 @@ +from typing import Dict +from szurubooru import model, rest +from szurubooru.func import auth, users, user_tokens, serialization, versions + + +def _serialize( + ctx: rest.Context, user_token: model.UserToken) -> rest.Response: + return user_tokens.serialize_user_token( + user_token, + ctx.user, + options=serialization.get_serialization_options(ctx)) + + +@rest.routes.get('/user-tokens/(?P[^/]+)/?') +def get_user_tokens( + ctx: rest.Context, params: Dict[str, str] = {}) -> rest.Response: + user = users.get_user_by_name(params['user_name']) + infix = 'self' if ctx.user.user_id == user.user_id else 'any' + auth.verify_privilege(ctx.user, 'user_tokens:list:%s' % infix) + user_token_list = user_tokens.get_user_tokens(user) + return { + 'results': [_serialize(ctx, token) for token in user_token_list] + } + + +@rest.routes.post('/user-token/(?P[^/]+)/?') +def create_user_token( + ctx: rest.Context, params: Dict[str, str] = {}) -> rest.Response: + user = users.get_user_by_name(params['user_name']) + infix = 'self' if ctx.user.user_id == user.user_id else 'any' + auth.verify_privilege(ctx.user, 'user_tokens:create:%s' % infix) + enabled = ctx.get_param_as_bool('enabled', True) + user_token = user_tokens.create_user_token(user, enabled) + if ctx.has_param('note'): + note = ctx.get_param_as_string('note') + user_tokens.update_user_token_note(user_token, note) + if ctx.has_param('expirationTime'): + expiration_time = ctx.get_param_as_string('expirationTime') + user_tokens.update_user_token_expiration_time( + user_token, expiration_time) + ctx.session.add(user_token) + ctx.session.commit() + return _serialize(ctx, user_token) + + +@rest.routes.put('/user-token/(?P[^/]+)/(?P[^/]+)/?') +def update_user_token( + ctx: rest.Context, params: Dict[str, str] = {}) -> rest.Response: + user = users.get_user_by_name(params['user_name']) + infix = 'self' if ctx.user.user_id == user.user_id else 'any' + auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) + user_token = user_tokens.get_by_user_and_token(user, params['user_token']) + versions.verify_version(user_token, ctx) + versions.bump_version(user_token) + if ctx.has_param('enabled'): + auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) + user_tokens.update_user_token_enabled( + user_token, ctx.get_param_as_bool('enabled')) + if ctx.has_param('note'): + auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) + note = ctx.get_param_as_string('note') + user_tokens.update_user_token_note(user_token, note) + if ctx.has_param('expirationTime'): + auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix) + expiration_time = ctx.get_param_as_string('expirationTime') + user_tokens.update_user_token_expiration_time( + user_token, expiration_time) + user_tokens.update_user_token_edit_time(user_token) + ctx.session.commit() + return _serialize(ctx, user_token) + + +@rest.routes.delete('/user-token/(?P[^/]+)/(?P[^/]+)/?') +def delete_user_token( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + user = users.get_user_by_name(params['user_name']) + infix = 'self' if ctx.user.user_id == user.user_id else 'any' + auth.verify_privilege(ctx.user, 'user_tokens:delete:%s' % infix) + user_token = user_tokens.get_by_user_and_token(user, params['user_token']) + if user_token is not None: + ctx.session.delete(user_token) + ctx.session.commit() + return {} diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index c9740fe0..65be79ac 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -1,10 +1,12 @@ -from typing import Tuple +from typing import Tuple, Optional import hashlib import random +import uuid from collections import OrderedDict +from datetime import datetime from nacl import pwhash from nacl.exceptions import InvalidkeyError -from szurubooru import config, model, errors, db +from szurubooru import config, db, model, errors from szurubooru.func import util @@ -26,7 +28,8 @@ def get_password_hash(salt: str, password: str) -> Tuple[str, int]: ).decode('utf8'), 3 -def get_sha256_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]: +def get_sha256_legacy_password_hash( + salt: str, password: str) -> Tuple[str, int]: ''' Retrieve old-style sha256 password hash. ''' digest = hashlib.sha256() digest.update(config.config['secret'].encode('utf8')) @@ -78,6 +81,21 @@ def is_valid_password(user: model.User, password: str) -> bool: return False +def is_valid_token(user_token: Optional[model.UserToken]) -> bool: + ''' + Token must be enabled and if it has an expiration, it must be + greater than now. + ''' + if user_token is None: + return False + if not user_token.enabled: + return False + if (user_token.expiration_time is not None + and user_token.expiration_time < datetime.utcnow()): + return False + return True + + def has_privilege(user: model.User, privilege_name: str) -> bool: assert user all_ranks = list(RANK_MAP.keys()) @@ -102,3 +120,7 @@ def generate_authentication_token(user: model.User) -> str: digest.update(config.config['secret'].encode('utf8')) digest.update(user.password_salt.encode('utf8')) return digest.hexdigest() + + +def generate_authorization_token() -> str: + return uuid.uuid4().__str__() diff --git a/server/szurubooru/func/user_tokens.py b/server/szurubooru/func/user_tokens.py new file mode 100644 index 00000000..c0f4badb --- /dev/null +++ b/server/szurubooru/func/user_tokens.py @@ -0,0 +1,146 @@ +from datetime import datetime +from typing import Any, Optional, List, Dict, Callable +from pyrfc3339 import parser as rfc3339_parser +import pytz +from szurubooru import db, model, rest, errors +from szurubooru.func import auth, serialization, users, util + + +class InvalidExpirationError(errors.ValidationError): + pass + + +class InvalidNoteError(errors.ValidationError): + pass + + +class UserTokenSerializer(serialization.BaseSerializer): + def __init__( + self, + user_token: model.UserToken, + auth_user: model.User) -> None: + self.user_token = user_token + self.auth_user = auth_user + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'user': self.serialize_user, + 'token': self.serialize_token, + 'note': self.serialize_note, + 'enabled': self.serialize_enabled, + 'expirationTime': self.serialize_expiration_time, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'lastUsageTime': self.serialize_last_usage_time, + 'version': self.serialize_version, + } + + def serialize_user(self) -> Any: + return users.serialize_micro_user(self.user_token.user, self.auth_user) + + def serialize_creation_time(self) -> Any: + return self.user_token.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.user_token.last_edit_time + + def serialize_last_usage_time(self) -> Any: + return self.user_token.last_usage_time + + def serialize_token(self) -> Any: + return self.user_token.token + + def serialize_note(self) -> Any: + return self.user_token.note + + def serialize_enabled(self) -> Any: + return self.user_token.enabled + + def serialize_expiration_time(self) -> Any: + return self.user_token.expiration_time + + def serialize_version(self) -> Any: + return self.user_token.version + + +def serialize_user_token( + user_token: Optional[model.UserToken], + auth_user: model.User, + options: List[str] = []) -> Optional[rest.Response]: + if not user_token: + return None + return UserTokenSerializer(user_token, auth_user).serialize(options) + + +def get_by_user_and_token( + user: model.User, token: str) -> model.UserToken: + return ( + db.session + .query(model.UserToken) + .filter(model.UserToken.user_id == user.user_id) + .filter(model.UserToken.token == token) + .one_or_none()) + + +def get_user_tokens(user: model.User) -> List[model.UserToken]: + assert user + return ( + db.session + .query(model.UserToken) + .filter(model.UserToken.user_id == user.user_id) + .all()) + + +def create_user_token(user: model.User, enabled: bool) -> model.UserToken: + assert user + user_token = model.UserToken() + user_token.user = user + user_token.token = auth.generate_authorization_token() + user_token.enabled = enabled + user_token.creation_time = datetime.utcnow() + user_token.last_usage_time = datetime.utcnow() + return user_token + + +def update_user_token_enabled( + user_token: model.UserToken, enabled: bool) -> None: + assert user_token + user_token.enabled = enabled + update_user_token_edit_time(user_token) + + +def update_user_token_edit_time(user_token: model.UserToken) -> None: + assert user_token + user_token.last_edit_time = datetime.utcnow() + + +def update_user_token_expiration_time( + user_token: model.UserToken, expiration_time_str: str) -> None: + assert user_token + try: + expiration_time = rfc3339_parser.parse(expiration_time_str, utc=True) + expiration_time = expiration_time.astimezone(pytz.UTC) + if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC): + raise InvalidExpirationError( + 'Expiration cannot happen in the past') + user_token.expiration_time = expiration_time + update_user_token_edit_time(user_token) + except ValueError: + raise InvalidExpirationError( + 'Expiration is in an invalid format {}'.format( + expiration_time_str)) + + +def update_user_token_note(user_token: model.UserToken, note: str) -> None: + assert user_token + note = note.strip() if note is not None else '' + note = None if len(note) == 0 else note + if util.value_exceeds_column_size(note, model.UserToken.note): + raise InvalidNoteError('Note is too long.') + user_token.note = note + update_user_token_edit_time(user_token) + + +def bump_usage_time(user_token: model.UserToken) -> None: + assert user_token + user_token.last_usage_time = datetime.utcnow() diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 012debca..e5946dc9 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -1,6 +1,6 @@ -import re -from typing import Any, Optional, Union, List, Dict, Callable from datetime import datetime +from typing import Any, Optional, Union, List, Dict, Callable +import re import sqlalchemy as sa from szurubooru import config, db, model, errors, rest from szurubooru.func import auth, util, serialization, files, images diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index ba2d4dc9..5e822866 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -160,6 +160,12 @@ def value_exceeds_column_size(value: Optional[str], column: Any) -> bool: return len(value) > max_length +def get_column_size(column: Any) -> Optional[int]: + if not column: + return None + return column.property.columns[0].type.length + + def chunks(source_list: List[Any], part_size: int) -> Generator: for i in range(0, len(source_list), part_size): yield source_list[i:i + part_size] diff --git a/server/szurubooru/middleware/authenticator.py b/server/szurubooru/middleware/authenticator.py index 644fe3b3..4340ec94 100644 --- a/server/szurubooru/middleware/authenticator.py +++ b/server/szurubooru/middleware/authenticator.py @@ -1,11 +1,11 @@ import base64 -from typing import Optional -from szurubooru import db, model, errors, rest -from szurubooru.func import auth, users +from typing import Optional, Tuple +from szurubooru import model, errors, rest +from szurubooru.func import auth, users, user_tokens from szurubooru.rest.errors import HttpBadRequest -def _authenticate(username: str, password: str) -> model.User: +def _authenticate_basic_auth(username: str, password: str) -> model.User: ''' Try to authenticate user. Throw AuthError for invalid users. ''' user = users.get_user_by_name(username) if not auth.is_valid_password(user, password): @@ -13,34 +13,61 @@ def _authenticate(username: str, password: str) -> model.User: return user -def _get_user(ctx: rest.Context) -> Optional[model.User]: +def _authenticate_token( + username: str, token: str) -> Tuple[model.User, model.UserToken]: + ''' Try to authenticate user. Throw AuthError for invalid users. ''' + user = users.get_user_by_name(username) + user_token = user_tokens.get_by_user_and_token(user, token) + if not auth.is_valid_token(user_token): + raise errors.AuthError('Invalid token.') + return user, user_token + + +def _get_user(ctx: rest.Context, bump_login: bool) -> Optional[model.User]: if not ctx.has_header('Authorization'): return None + auth_token = None + try: auth_type, credentials = ctx.get_header('Authorization').split(' ', 1) - if auth_type.lower() != 'basic': + if auth_type.lower() == 'basic': + username, password = base64.decodebytes( + credentials.encode('ascii')).decode('utf8').split(':', 1) + auth_user = _authenticate_basic_auth(username, password) + elif auth_type.lower() == 'token': + username, token = base64.decodebytes( + credentials.encode('ascii')).decode('utf8').split(':', 1) + auth_user, auth_token = _authenticate_token(username, token) + else: raise HttpBadRequest( 'ValidationError', - 'Only basic HTTP authentication is supported.') - username, password = base64.decodebytes( - credentials.encode('ascii')).decode('utf8').split(':', 1) - return _authenticate(username, password) + 'Only basic or token HTTP authentication is supported.') except ValueError as err: msg = ( - 'Basic authentication header value are not properly formed. ' + 'Authorization header values are not properly formed. ' 'Supplied header {0}. Got error: {1}') raise HttpBadRequest( 'ValidationError', msg.format(ctx.get_header('Authorization'), str(err))) + if bump_login and auth_user.user_id: + users.bump_user_login_time(auth_user) + if auth_token is not None: + user_tokens.bump_usage_time(auth_token) + ctx.session.commit() + + return auth_user + -@rest.middleware.pre_hook def process_request(ctx: rest.Context) -> None: ''' Bind the user to request. Update last login time if needed. ''' - auth_user = _get_user(ctx) + bump_login = ctx.get_param_as_bool('bump-login', default=False) + auth_user = _get_user(ctx, bump_login) if auth_user: ctx.user = auth_user - if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id: - users.bump_user_login_time(ctx.user) - ctx.session.commit() + + +@rest.middleware.pre_hook +def process_request_hook(ctx: rest.Context) -> None: + process_request(ctx) diff --git a/server/szurubooru/migrations/versions/a39c7f98a7fa_add_user_token_table.py b/server/szurubooru/migrations/versions/a39c7f98a7fa_add_user_token_table.py new file mode 100644 index 00000000..899eaa70 --- /dev/null +++ b/server/szurubooru/migrations/versions/a39c7f98a7fa_add_user_token_table.py @@ -0,0 +1,39 @@ +''' +Added a user_token table for API authorization + +Revision ID: a39c7f98a7fa +Created at: 2018-02-25 01:31:27.345595 +''' + +import sqlalchemy as sa +from alembic import op + + +revision = 'a39c7f98a7fa' +down_revision = '9ef1a1643c2a' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'user_token', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('token', sa.Unicode(length=36), nullable=False), + sa.Column('note', sa.Unicode(length=128), nullable=True), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.Column('expiration_time', sa.DateTime(), nullable=True), + sa.Column('creation_time', sa.DateTime(), nullable=False), + sa.Column('last_edit_time', sa.DateTime(), nullable=True), + sa.Column('last_usage_time', sa.DateTime(), nullable=True), + sa.Column('version', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id')) + op.create_index( + op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False) + + +def downgrade(): + op.drop_index(op.f('ix_user_token_user_id'), table_name='user_token') + op.drop_table('user_token') diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py index ad2231c2..4892b974 100644 --- a/server/szurubooru/model/__init__.py +++ b/server/szurubooru/model/__init__.py @@ -1,5 +1,5 @@ from szurubooru.model.base import Base -from szurubooru.model.user import User +from szurubooru.model.user import User, UserToken from szurubooru.model.tag_category import TagCategory from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication from szurubooru.model.post import ( diff --git a/server/szurubooru/model/user.py b/server/szurubooru/model/user.py index 39c5a91b..2d599e85 100644 --- a/server/szurubooru/model/user.py +++ b/server/szurubooru/model/user.py @@ -86,3 +86,25 @@ class User(Base): 'version_id_col': version, 'version_id_generator': False, } + + +class UserToken(Base): + __tablename__ = 'user_token' + + user_token_id = sa.Column('id', sa.Integer, primary_key=True) + user_id = sa.Column( + 'user_id', + sa.Integer, + sa.ForeignKey('user.id', ondelete='CASCADE'), + nullable=False, + index=True) + token = sa.Column('token', sa.Unicode(36), nullable=False) + note = sa.Column('note', sa.Unicode(128), nullable=True) + enabled = sa.Column('enabled', sa.Boolean, nullable=False, default=True) + expiration_time = sa.Column('expiration_time', sa.DateTime, nullable=True) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_edit_time = sa.Column('last_edit_time', sa.DateTime) + last_usage_time = sa.Column('last_usage_time', sa.DateTime) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + + user = sa.orm.relationship('User') diff --git a/server/szurubooru/rest/__init__.py b/server/szurubooru/rest/__init__.py index 14a3e305..d6b3ef28 100644 --- a/server/szurubooru/rest/__init__.py +++ b/server/szurubooru/rest/__init__.py @@ -1,2 +1,3 @@ from szurubooru.rest.app import application from szurubooru.rest.context import Context, Response +import szurubooru.rest.routes diff --git a/server/szurubooru/tests/api/test_user_token_creating.py b/server/szurubooru/tests/api/test_user_token_creating.py new file mode 100644 index 00000000..f550f63f --- /dev/null +++ b/server/szurubooru/tests/api/test_user_token_creating.py @@ -0,0 +1,29 @@ +from unittest.mock import patch +import pytest +from szurubooru import api +from szurubooru.func import user_tokens, users + + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'user_tokens:create:self': 'regular'}}) + + +def test_creating_user_token( + user_token_factory, context_factory, fake_datetime): + user_token = user_token_factory() + with patch('szurubooru.func.user_tokens.create_user_token'), \ + patch('szurubooru.func.user_tokens.serialize_user_token'), \ + patch('szurubooru.func.users.get_user_by_name'), \ + fake_datetime('1969-02-12'): + users.get_user_by_name.return_value = user_token.user + user_tokens.serialize_user_token.return_value = 'serialized user token' + user_tokens.create_user_token.return_value = user_token + result = api.user_token_api.create_user_token( + context_factory(user=user_token.user), + { + 'user_name': user_token.user.name + }) + assert result == 'serialized user token' + user_tokens.create_user_token.assert_called_once_with( + user_token.user, True) diff --git a/server/szurubooru/tests/api/test_user_token_deleting.py b/server/szurubooru/tests/api/test_user_token_deleting.py new file mode 100644 index 00000000..85341522 --- /dev/null +++ b/server/szurubooru/tests/api/test_user_token_deleting.py @@ -0,0 +1,30 @@ +from unittest.mock import patch +import pytest +from szurubooru import api, db +from szurubooru.func import user_tokens, users + + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'user_tokens:delete:self': 'regular'}}) + + +def test_deleting_user_token( + user_token_factory, context_factory, fake_datetime): + user_token = user_token_factory() + db.session.add(user_token) + db.session.commit() + with patch('szurubooru.func.user_tokens.get_by_user_and_token'), \ + patch('szurubooru.func.users.get_user_by_name'), \ + fake_datetime('1969-02-12'): + users.get_user_by_name.return_value = user_token.user + user_tokens.get_by_user_and_token.return_value = user_token + result = api.user_token_api.delete_user_token( + context_factory(user=user_token.user), + { + 'user_name': user_token.user.name, + 'user_token': user_token.token + }) + assert result == {} + user_tokens.get_by_user_and_token.assert_called_once_with( + user_token.user, user_token.token) diff --git a/server/szurubooru/tests/api/test_user_token_retrieving.py b/server/szurubooru/tests/api/test_user_token_retrieving.py new file mode 100644 index 00000000..01b25342 --- /dev/null +++ b/server/szurubooru/tests/api/test_user_token_retrieving.py @@ -0,0 +1,31 @@ +from unittest.mock import patch +import pytest +from szurubooru import api +from szurubooru.func import user_tokens, users + + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'user_tokens:list:self': 'regular'}}) + + +def test_retrieving_user_tokens( + user_token_factory, context_factory, fake_datetime): + user_token1 = user_token_factory() + user_token2 = user_token_factory(user=user_token1.user) + user_token3 = user_token_factory(user=user_token1.user) + with patch('szurubooru.func.user_tokens.get_user_tokens'), \ + patch('szurubooru.func.user_tokens.serialize_user_token'), \ + patch('szurubooru.func.users.get_user_by_name'), \ + fake_datetime('1969-02-12'): + users.get_user_by_name.return_value = user_token1.user + user_tokens.serialize_user_token.return_value = 'serialized user token' + user_tokens.get_user_tokens.return_value = [user_token1, user_token2, + user_token3] + result = api.user_token_api.get_user_tokens( + context_factory(user=user_token1.user), + { + 'user_name': user_token1.user.name + }) + assert result == {'results': ['serialized user token'] * 3} + user_tokens.get_user_tokens.assert_called_once_with(user_token1.user) diff --git a/server/szurubooru/tests/api/test_user_token_updating.py b/server/szurubooru/tests/api/test_user_token_updating.py new file mode 100644 index 00000000..bf725a35 --- /dev/null +++ b/server/szurubooru/tests/api/test_user_token_updating.py @@ -0,0 +1,42 @@ +from unittest.mock import patch +import pytest +from szurubooru import api, db +from szurubooru.func import user_tokens, users + + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'user_tokens:edit:self': 'regular'}}) + + +def test_edit_user_token(user_token_factory, context_factory, fake_datetime): + user_token = user_token_factory() + db.session.add(user_token) + db.session.commit() + with patch('szurubooru.func.user_tokens.get_by_user_and_token'), \ + patch('szurubooru.func.user_tokens.update_user_token_enabled'), \ + patch('szurubooru.func.user_tokens.update_user_token_edit_time'), \ + patch('szurubooru.func.user_tokens.serialize_user_token'), \ + patch('szurubooru.func.users.get_user_by_name'), \ + fake_datetime('1969-02-12'): + users.get_user_by_name.return_value = user_token.user + user_tokens.serialize_user_token.return_value = 'serialized user token' + user_tokens.get_by_user_and_token.return_value = user_token + result = api.user_token_api.update_user_token( + context_factory( + params={ + 'version': user_token.version, + 'enabled': False, + }, + user=user_token.user), + { + 'user_name': user_token.user.name, + 'user_token': user_token.token + }) + assert result == 'serialized user token' + user_tokens.get_by_user_and_token.assert_called_once_with( + user_token.user, user_token.token) + user_tokens.update_user_token_enabled.assert_called_once_with( + user_token, False) + user_tokens.update_user_token_edit_time.assert_called_once_with( + user_token) diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index db7806e8..f6eeee18 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -93,11 +93,11 @@ def session(query_logger): # pylint: disable=unused-argument @pytest.fixture def context_factory(session): - def factory(params=None, files=None, user=None): + def factory(params=None, files=None, user=None, headers=None): ctx = rest.Context( method=None, url=None, - headers={}, + headers=headers or {}, params=params or {}, files=files or {}) ctx.session = session @@ -133,6 +133,27 @@ def user_factory(): return factory +@pytest.fixture +def user_token_factory(user_factory): + def factory( + user=None, + token=None, + expiration_time=None, + enabled=None, + creation_time=None): + if user is None: + user = user_factory() + db.session.add(user) + user_token = model.UserToken() + user_token.user = user + user_token.token = token or 'dummy' + user_token.expiration_time = expiration_time + user_token.enabled = enabled if enabled is not None else True + user_token.creation_time = creation_time or datetime(1997, 1, 1) + return user_token + return factory + + @pytest.fixture def tag_category_factory(): def factory(name=None, color='dummy', default=False): diff --git a/server/szurubooru/tests/func/test_auth.py b/server/szurubooru/tests/func/test_auth.py index 5d0955ad..6dc79bb5 100644 --- a/server/szurubooru/tests/func/test_auth.py +++ b/server/szurubooru/tests/func/test_auth.py @@ -1,5 +1,6 @@ -from szurubooru.func import auth +from datetime import datetime, timedelta import pytest +from szurubooru.func import auth @pytest.fixture(autouse=True) @@ -41,3 +42,24 @@ def test_is_valid_password_auto_upgrades_user_password_hash(user_factory): assert result is True assert user.password_hash != hash assert user.password_revision > revision + + +def test_is_valid_token(user_token_factory): + user_token = user_token_factory() + assert auth.is_valid_token(user_token) + + +def test_expired_token_is_invalid(user_token_factory): + past_expiration = (datetime.utcnow() - timedelta(minutes=30)) + user_token = user_token_factory(expiration_time=past_expiration) + assert not auth.is_valid_token(user_token) + + +def test_disabled_token_is_invalid(user_token_factory): + user_token = user_token_factory(enabled=False) + assert not auth.is_valid_token(user_token) + + +def test_generate_authorization_token(): + result = auth.generate_authorization_token() + assert result != auth.generate_authorization_token() diff --git a/server/szurubooru/tests/func/test_user_tokens.py b/server/szurubooru/tests/func/test_user_tokens.py new file mode 100644 index 00000000..8c3577c8 --- /dev/null +++ b/server/szurubooru/tests/func/test_user_tokens.py @@ -0,0 +1,155 @@ +from datetime import datetime, timedelta +from unittest.mock import patch +import pytest +import pytz +import random +import string +from szurubooru import db, model +from szurubooru.func import user_tokens, users, auth, util + + +def test_serialize_user_token(user_token_factory): + user_token = user_token_factory() + db.session.add(user_token) + db.session.flush() + with patch('szurubooru.func.users.get_avatar_url'): + users.get_avatar_url.return_value = 'https://example.com/avatar.png' + result = user_tokens.serialize_user_token(user_token, user_token.user) + assert result == { + 'creationTime': datetime(1997, 1, 1, 0, 0), + 'enabled': True, + 'expirationTime': None, + 'lastEditTime': None, + 'lastUsageTime': None, + 'note': None, + 'token': 'dummy', + 'user': { + 'avatarUrl': 'https://example.com/avatar.png', + 'name': user_token.user.name}, + 'version': 1 + } + + +def test_serialize_user_token_none(): + result = user_tokens.serialize_user_token(None, None) + assert result is None + + +def test_get_by_user_and_token(user_token_factory): + user_token = user_token_factory() + db.session.add(user_token) + db.session.flush() + db.session.commit() + result = user_tokens.get_by_user_and_token( + user_token.user, user_token.token) + assert result == user_token + + +def test_get_user_tokens(user_token_factory): + user_token1 = user_token_factory() + user_token2 = user_token_factory(user=user_token1.user) + db.session.add(user_token1) + db.session.add(user_token2) + db.session.flush() + db.session.commit() + result = user_tokens.get_user_tokens(user_token1.user) + assert result == [user_token1, user_token2] + + +def test_create_user_token(user_factory): + user = user_factory() + db.session.add(user) + db.session.flush() + db.session.commit() + with patch('szurubooru.func.auth.generate_authorization_token'): + auth.generate_authorization_token.return_value = 'test' + result = user_tokens.create_user_token(user, True) + assert result.token == 'test' + assert result.user == user + + +def test_update_user_token_enabled(user_token_factory): + user_token = user_token_factory() + user_tokens.update_user_token_enabled(user_token, False) + assert user_token.enabled is False + assert user_token.last_edit_time is not None + + +def test_update_user_token_edit_time(user_token_factory): + user_token = user_token_factory() + assert user_token.last_edit_time is None + user_tokens.update_user_token_edit_time(user_token) + assert user_token.last_edit_time is not None + + +def test_update_user_token_note(user_token_factory): + user_token = user_token_factory() + assert user_token.note is None + user_tokens.update_user_token_note(user_token, ' Test Note ') + assert user_token.note == 'Test Note' + assert user_token.last_edit_time is not None + + +def test_update_user_token_note_input_too_long(user_token_factory): + user_token = user_token_factory() + assert user_token.note is None + note_max_length = util.get_column_size(model.UserToken.note) + 1 + note = ''.join( + random.choice(string.ascii_letters) for _ in range(note_max_length)) + with pytest.raises(user_tokens.InvalidNoteError): + user_tokens.update_user_token_note(user_token, note) + + +def test_update_user_token_expiration_time(user_token_factory): + user_token = user_token_factory() + assert user_token.expiration_time is None + expiration_time_str = ( + (datetime.utcnow() + timedelta(days=1)) + .replace(tzinfo=pytz.utc) + ).isoformat() + user_tokens.update_user_token_expiration_time( + user_token, expiration_time_str) + assert user_token.expiration_time.isoformat() == expiration_time_str + assert user_token.last_edit_time is not None + + +def test_update_user_token_expiration_time_in_past(user_token_factory): + user_token = user_token_factory() + assert user_token.expiration_time is None + expiration_time_str = ( + (datetime.utcnow() - timedelta(days=1)) + .replace(tzinfo=pytz.utc) + ).isoformat() + with pytest.raises( + user_tokens.InvalidExpirationError, + match='Expiration cannot happen in the past'): + user_tokens.update_user_token_expiration_time( + user_token, expiration_time_str) + + +@pytest.mark.parametrize('expiration_time_str', [ + datetime.utcnow().isoformat(), + (datetime.utcnow() - timedelta(days=1)).ctime(), + '1970/01/01 00:00:01.0000Z', + '70/01/01 00:00:01.0000Z', + ''.join(random.choice(string.ascii_letters) for _ in range(15)), + ''.join(random.choice(string.digits) for _ in range(8)) +]) +def test_update_user_token_expiration_time_invalid_format( + expiration_time_str, user_token_factory): + user_token = user_token_factory() + assert user_token.expiration_time is None + + with pytest.raises( + user_tokens.InvalidExpirationError, + match='Expiration is in an invalid format %s' + % expiration_time_str): + user_tokens.update_user_token_expiration_time( + user_token, expiration_time_str) + + +def test_bump_usage_time(user_token_factory, fake_datetime): + user_token = user_token_factory() + with fake_datetime('1997-01-01'): + user_tokens.bump_usage_time(user_token) + assert user_token.last_usage_time == datetime(1997, 1, 1) diff --git a/server/szurubooru/tests/middleware/__init__.py b/server/szurubooru/tests/middleware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/middleware/test_authenticator.py b/server/szurubooru/tests/middleware/test_authenticator.py new file mode 100644 index 00000000..be21a931 --- /dev/null +++ b/server/szurubooru/tests/middleware/test_authenticator.py @@ -0,0 +1,93 @@ +from unittest.mock import patch +import pytest +from szurubooru import db +from szurubooru.func import auth, users, user_tokens +from szurubooru.middleware import authenticator +from szurubooru.rest import errors + + +def test_process_request_no_header(context_factory): + ctx = context_factory() + authenticator.process_request(ctx) + assert ctx.user.name is None + + +def test_process_request_bump_login(context_factory, user_factory): + user = user_factory() + db.session.add(user) + db.session.flush() + ctx = context_factory( + headers={ + 'Authorization': 'Basic dGVzdFVzZXI6dGVzdFRva2Vu' + }, + params={ + 'bump-login': 'true' + }) + with patch('szurubooru.func.auth.is_valid_password'), \ + patch('szurubooru.func.users.get_user_by_name'): + users.get_user_by_name.return_value = user + auth.is_valid_password.return_value = True + authenticator.process_request(ctx) + assert user.last_login_time is not None + + +def test_process_request_bump_login_with_token( + context_factory, user_token_factory): + user_token = user_token_factory() + db.session.add(user_token) + db.session.flush() + ctx = context_factory( + headers={ + 'Authorization': 'Token dGVzdFVzZXI6dGVzdFRva2Vu' + }, + params={ + 'bump-login': 'true' + }) + with patch('szurubooru.func.auth.is_valid_token'), \ + patch('szurubooru.func.users.get_user_by_name'), \ + patch('szurubooru.func.user_tokens.get_by_user_and_token'): + users.get_user_by_name.return_value = user_token.user + user_tokens.get_by_user_and_token.return_value = user_token + auth.is_valid_token.return_value = True + authenticator.process_request(ctx) + assert user_token.user.last_login_time is not None + assert user_token.last_usage_time is not None + + +def test_process_request_basic_auth_valid(context_factory, user_factory): + user = user_factory() + ctx = context_factory( + headers={ + 'Authorization': 'Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk' + }) + with patch('szurubooru.func.auth.is_valid_password'), \ + patch('szurubooru.func.users.get_user_by_name'): + users.get_user_by_name.return_value = user + auth.is_valid_password.return_value = True + authenticator.process_request(ctx) + assert ctx.user == user + + +def test_process_request_token_auth_valid(context_factory, user_token_factory): + user_token = user_token_factory() + ctx = context_factory( + headers={ + 'Authorization': 'Token dGVzdFVzZXI6dGVzdFRva2Vu' + }) + with patch('szurubooru.func.auth.is_valid_token'), \ + patch('szurubooru.func.users.get_user_by_name'), \ + patch('szurubooru.func.user_tokens.get_by_user_and_token'): + users.get_user_by_name.return_value = user_token.user + user_tokens.get_by_user_and_token.return_value = user_token + auth.is_valid_token.return_value = True + authenticator.process_request(ctx) + assert ctx.user == user_token.user + + +def test_process_request_bad_header(context_factory): + ctx = context_factory( + headers={ + 'Authorization': 'Secret SuperSecretValue' + }) + with pytest.raises(errors.HttpBadRequest): + authenticator.process_request(ctx) diff --git a/server/szurubooru/tests/model/test_user_token.py b/server/szurubooru/tests/model/test_user_token.py new file mode 100644 index 00000000..0280082e --- /dev/null +++ b/server/szurubooru/tests/model/test_user_token.py @@ -0,0 +1,14 @@ +from datetime import datetime +from szurubooru import db + + +def test_saving_user_token(user_token_factory): + user_token = user_token_factory() + db.session.add(user_token) + db.session.flush() + db.session.refresh(user_token) + assert not db.session.dirty + assert user_token.user is not None + assert user_token.token == 'dummy' + assert user_token.enabled is True + assert user_token.creation_time == datetime(1997, 1, 1) From d39439d549a003880933d730f80c56eaf2cf20fa Mon Sep 17 00:00:00 2001 From: Michael Serajnik Date: Thu, 5 Apr 2018 19:40:53 +0200 Subject: [PATCH 115/159] client/posts: fix viewport height calculation on iOS --- client/js/views/post_main_view.js | 10 ++++------ client/package-lock.json | 5 +++++ client/package.json | 1 + 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/client/js/views/post_main_view.js b/client/js/views/post_main_view.js index 3aa6c4fb..141de712 100644 --- a/client/js/views/post_main_view.js +++ b/client/js/views/post_main_view.js @@ -1,5 +1,6 @@ 'use strict'; +const iosCorrectedInnerHeight = require('ios-inner-height'); const router = require('../router.js'); const views = require('../util/views.js'); const uri = require('../util/uri.js'); @@ -26,23 +27,20 @@ class PostMainView { views.replaceContent(this._hostNode, sourceNode); views.syncScrollPosition(); - const postViewNode = document.body.querySelector('.content-wrapper'); const topNavigationNode = document.body.querySelector('#top-navigation'); - const margin = ( - postViewNode.getBoundingClientRect().top - - topNavigationNode.getBoundingClientRect().height); - this._postContentControl = new PostContentControl( postContainerNode, ctx.post, () => { + const margin = sidebarNode.getBoundingClientRect().left; + return [ window.innerWidth - postContainerNode.getBoundingClientRect().left - margin, - window.innerHeight - + iosCorrectedInnerHeight() - topNavigationNode.getBoundingClientRect().height - margin * 2, ]; diff --git a/client/package-lock.json b/client/package-lock.json index d54d15bc..f15845e6 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -1482,6 +1482,11 @@ "loose-envify": "1.3.1" } }, + "ios-inner-height": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/ios-inner-height/-/ios-inner-height-1.0.3.tgz", + "integrity": "sha512-GayJWoFxYHDx/gkfz4nIxNdsqB3nAJQHKV5pDBvig6he8+NxBSYxN+D7oarbqZfW2p6uera3q9NDr4Jgdafiog==" + }, "is-buffer": { "version": "1.1.5", "resolved": "https://registry.npmjs.org/is-buffer/-/is-buffer-1.1.5.tgz", diff --git a/client/package.json b/client/package.json index c482147e..e7fac031 100644 --- a/client/package.json +++ b/client/package.json @@ -16,6 +16,7 @@ "font-awesome": "^4.6.1", "glob": "^7.1.2", "html-minifier": "^1.3.1", + "ios-inner-height": "^1.0.3", "js-cookie": "^2.2.0", "js-yaml": "^3.10.0", "marked": "^0.3.9", From 2bf361c64a46d575533f68f8ced88e037f887335 Mon Sep 17 00:00:00 2001 From: Nesswit Date: Tue, 22 May 2018 02:51:38 +0900 Subject: [PATCH 116/159] client/posts: fix upload error caused by anonymous node Anonymous node does not exist in view when a user without anonymous upload permission tries to post upload. So in this case we should check for the existence of anonymousNode first. --- client/js/views/post_upload_view.js | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/client/js/views/post_upload_view.js b/client/js/views/post_upload_view.js index b64fe260..b4d0c409 100644 --- a/client/js/views/post_upload_view.js +++ b/client/js/views/post_upload_view.js @@ -283,8 +283,10 @@ class PostUploadView extends events.EventTarget { uploadable.safety = safetyNode.value; } - uploadable.anonymous = - rowNode.querySelector('.anonymous input').checked; + const anonymousNode = rowNode.querySelector('.anonymous input:checked'); + if (anonymousNode) { + uploadable.anonymous = true; + } uploadable.flags = []; if (rowNode.querySelector('.loop-video input:checked')) { From 3972b902d840e54bd4f8dd47a9d13bc20ee22c65 Mon Sep 17 00:00:00 2001 From: Shyam Sunder Date: Mon, 25 Jun 2018 10:47:20 -0400 Subject: [PATCH 117/159] client: fetch configurations from server at runtime Permissions, regex filters, app title, email info, and safety now fetched using server's Info API --- client/js/api.js | 48 ++++++++++- client/js/controllers/home_controller.js | 2 +- client/js/controllers/post_list_controller.js | 11 ++- .../js/controllers/post_upload_controller.js | 3 +- .../controllers/top_navigation_controller.js | 17 ++-- client/js/controllers/user_controller.js | 1 - .../js/controls/post_edit_sidebar_control.js | 3 +- .../controls/post_readonly_sidebar_control.js | 2 +- client/js/main.js | 83 ++++++++++--------- client/js/models/post_list.js | 2 +- client/js/models/top_navigation.js | 8 +- client/js/views/help_view.js | 4 +- client/js/views/login_view.js | 12 +-- client/js/views/not_found_view.js | 1 - client/js/views/password_reset_view.js | 6 +- client/js/views/post_merge_view.js | 1 - client/js/views/registration_view.js | 6 +- client/js/views/tag_edit_view.js | 4 +- client/js/views/tag_merge_view.js | 4 +- client/js/views/user_edit_view.js | 6 +- server/szurubooru/api/info_api.py | 4 + 21 files changed, 137 insertions(+), 91 deletions(-) diff --git a/client/js/api.js b/client/js/api.js index 3623045b..ba06f199 100644 --- a/client/js/api.js +++ b/client/js/api.js @@ -8,6 +8,7 @@ const progress = require('./util/progress.js'); const uri = require('./util/uri.js'); let fileTokens = {}; +let remoteConfig = null; class Api extends events.EventTarget { constructor() { @@ -65,14 +66,53 @@ class Api extends events.EventTarget { return this._wrappedRequest(url, request.delete, data, {}, options); } + fetchConfig() { + if (remoteConfig === null) { + return this.get(uri.formatApiLink('info')) + .then(response => { + remoteConfig = response.config; + }); + } else { + return Promise.resolve(); + } + } + + getName() { + return remoteConfig.name; + } + + getTagNameRegex() { + return remoteConfig.tagNameRegex; + } + + getPasswordRegex() { + return remoteConfig.passwordRegex; + } + + getUserNameRegex() { + return remoteConfig.userNameRegex; + } + + getContactEmail() { + return remoteConfig.contactEmail; + } + + canSendMails() { + return !!remoteConfig.canSendMails; + } + + safetyEnabled() { + return !!remoteConfig.enableSafety; + } + hasPrivilege(lookup) { let minViableRank = null; - for (let privilege of Object.keys(config.privileges)) { - if (!privilege.startsWith(lookup)) { + for (let p of Object.keys(remoteConfig.privileges)) { + if (!p.startsWith(lookup)) { continue; } - const rankName = config.privileges[privilege]; - const rankIndex = this.allRanks.indexOf(rankName); + const rankIndex = this.allRanks.indexOf( + remoteConfig.privileges[p]); if (minViableRank === null || rankIndex < minViableRank) { minViableRank = rankIndex; } diff --git a/client/js/controllers/home_controller.js b/client/js/controllers/home_controller.js index b7590a00..cc22ca95 100644 --- a/client/js/controllers/home_controller.js +++ b/client/js/controllers/home_controller.js @@ -12,7 +12,7 @@ class HomeController { topNavigation.setTitle('Home'); this._homeView = new HomeView({ - name: config.name, + name: api.getName(), version: config.meta.version, buildDate: config.meta.buildDate, canListSnapshots: api.hasPrivilege('snapshots:list'), diff --git a/client/js/controllers/post_list_controller.js b/client/js/controllers/post_list_controller.js index f60cb8b7..fd1adfea 100644 --- a/client/js/controllers/post_list_controller.js +++ b/client/js/controllers/post_list_controller.js @@ -1,6 +1,5 @@ 'use strict'; -const config = require('../config.js'); const router = require('../router.js'); const api = require('../api.js'); const settings = require('../models/settings.js'); @@ -33,9 +32,9 @@ class PostListController { this._headerView = new PostsHeaderView({ hostNode: this._pageController.view.pageHeaderHolderNode, parameters: ctx.parameters, - enableSafety: config.enableSafety, - canBulkEditTags: api.hasPrivilege('posts:bulkEdit:tags'), - canBulkEditSafety: api.hasPrivilege('posts:bulkEdit:safety'), + enableSafety: api.safetyEnabled(), + canBulkEditTags: api.hasPrivilege('posts:bulk-edit:tags'), + canBulkEditSafety: api.hasPrivilege('posts:bulk-edit:safety'), bulkEdit: { tags: this._bulkEditTags }, @@ -97,9 +96,9 @@ class PostListController { pageRenderer: pageCtx => { Object.assign(pageCtx, { canViewPosts: api.hasPrivilege('posts:view'), - canBulkEditTags: api.hasPrivilege('posts:bulkEdit:tags'), + canBulkEditTags: api.hasPrivilege('posts:bulk-edit:tags'), canBulkEditSafety: - api.hasPrivilege('posts:bulkEdit:safety'), + api.hasPrivilege('posts:bulk-edit:safety'), bulkEdit: { tags: this._bulkEditTags, }, diff --git a/client/js/controllers/post_upload_controller.js b/client/js/controllers/post_upload_controller.js index 9dfdb4af..ccd6f94c 100644 --- a/client/js/controllers/post_upload_controller.js +++ b/client/js/controllers/post_upload_controller.js @@ -1,7 +1,6 @@ 'use strict'; const api = require('../api.js'); -const config = require('../config.js'); const router = require('../router.js'); const uri = require('../util/uri.js'); const misc = require('../util/misc.js'); @@ -31,7 +30,7 @@ class PostUploadController { this._view = new PostUploadView({ canUploadAnonymously: api.hasPrivilege('posts:create:anonymous'), canViewPosts: api.hasPrivilege('posts:view'), - enableSafety: config.enableSafety, + enableSafety: api.safetyEnabled(), }); this._view.addEventListener('change', e => this._evtChange(e)); this._view.addEventListener('submit', e => this._evtSubmit(e)); diff --git a/client/js/controllers/top_navigation_controller.js b/client/js/controllers/top_navigation_controller.js index 0f56ddc0..fdf8117a 100644 --- a/client/js/controllers/top_navigation_controller.js +++ b/client/js/controllers/top_navigation_controller.js @@ -1,21 +1,22 @@ 'use strict'; const api = require('../api.js'); -const config = require('../config.js'); const topNavigation = require('../models/top_navigation.js'); const TopNavigationView = require('../views/top_navigation_view.js'); class TopNavigationController { constructor() { - this._topNavigationView = new TopNavigationView(); + api.fetchConfig().then(() => { + this._topNavigationView = new TopNavigationView(); - topNavigation.addEventListener( - 'activate', e => this._evtActivate(e)); + topNavigation.addEventListener( + 'activate', e => this._evtActivate(e)); - api.addEventListener('login', e => this._evtAuthChange(e)); - api.addEventListener('logout', e => this._evtAuthChange(e)); + api.addEventListener('login', e => this._evtAuthChange(e)); + api.addEventListener('logout', e => this._evtAuthChange(e)); - this._render(); + this._render(); + }); } _evtAuthChange(e) { @@ -65,7 +66,7 @@ class TopNavigationController { this._updateNavigationFromPrivileges(); this._topNavigationView.render({ items: topNavigation.getAll(), - name: config.name + name: api.getName() }); this._topNavigationView.activate( topNavigation.activeItem ? topNavigation.activeItem.key : ''); diff --git a/client/js/controllers/user_controller.js b/client/js/controllers/user_controller.js index d042e41f..48989af3 100644 --- a/client/js/controllers/user_controller.js +++ b/client/js/controllers/user_controller.js @@ -4,7 +4,6 @@ const router = require('../router.js'); const api = require('../api.js'); const uri = require('../util/uri.js'); const misc = require('../util/misc.js'); -const config = require('../config.js'); const views = require('../util/views.js'); const User = require('../models/user.js'); const UserToken = require('../models/user_token.js'); diff --git a/client/js/controls/post_edit_sidebar_control.js b/client/js/controls/post_edit_sidebar_control.js index 38dded28..dd8b66da 100644 --- a/client/js/controls/post_edit_sidebar_control.js +++ b/client/js/controls/post_edit_sidebar_control.js @@ -1,7 +1,6 @@ 'use strict'; const api = require('../api.js'); -const config = require('../config.js'); const events = require('../events.js'); const misc = require('../util/misc.js'); const views = require('../util/views.js'); @@ -26,7 +25,7 @@ class PostEditSidebarControl extends events.EventTarget { views.replaceContent(this._hostNode, template({ post: this._post, - enableSafety: config.enableSafety, + enableSafety: api.safetyEnabled(), hasClipboard: document.queryCommandSupported('copy'), canEditPostSafety: api.hasPrivilege('posts:edit:safety'), canEditPostSource: api.hasPrivilege('posts:edit:source'), diff --git a/client/js/controls/post_readonly_sidebar_control.js b/client/js/controls/post_readonly_sidebar_control.js index bb51e69d..579a38e6 100644 --- a/client/js/controls/post_readonly_sidebar_control.js +++ b/client/js/controls/post_readonly_sidebar_control.js @@ -21,7 +21,7 @@ class PostReadonlySidebarControl extends events.EventTarget { views.replaceContent(this._hostNode, template({ post: this._post, - enableSafety: config.enableSafety, + enableSafety: api.safetyEnabled(), canListPosts: api.hasPrivilege('posts:list'), canEditPosts: api.hasPrivilege('posts:edit'), canViewTags: api.hasPrivilege('tags:view'), diff --git a/client/js/main.js b/client/js/main.js index 71284f8e..5ddf26fe 100644 --- a/client/js/main.js +++ b/client/js/main.js @@ -26,46 +26,51 @@ router.enter( next(); }); -// register controller routes -let controllers = []; -controllers.push(require('./controllers/home_controller.js')); -controllers.push(require('./controllers/help_controller.js')); -controllers.push(require('./controllers/auth_controller.js')); -controllers.push(require('./controllers/password_reset_controller.js')); -controllers.push(require('./controllers/comments_controller.js')); -controllers.push(require('./controllers/snapshots_controller.js')); -controllers.push(require('./controllers/post_detail_controller.js')); -controllers.push(require('./controllers/post_main_controller.js')); -controllers.push(require('./controllers/post_list_controller.js')); -controllers.push(require('./controllers/post_upload_controller.js')); -controllers.push(require('./controllers/tag_controller.js')); -controllers.push(require('./controllers/tag_list_controller.js')); -controllers.push(require('./controllers/tag_categories_controller.js')); -controllers.push(require('./controllers/settings_controller.js')); -controllers.push(require('./controllers/user_controller.js')); -controllers.push(require('./controllers/user_list_controller.js')); -controllers.push(require('./controllers/user_registration_controller.js')); - -// 404 controller needs to be registered last -controllers.push(require('./controllers/not_found_controller.js')); - -for (let controller of controllers) { - controller(router); -} - const tags = require('./tags.js'); const api = require('./api.js'); tags.refreshCategoryColorMap(); // we don't care about errors -api.loginFromCookies().then(() => { - router.start(); - }, error => { - if (window.location.href.indexOf('login') !== -1) { - api.forget(); + +api.fetchConfig().then(() => { + // register controller routes + let controllers = []; + controllers.push(require('./controllers/home_controller.js')); + controllers.push(require('./controllers/help_controller.js')); + controllers.push(require('./controllers/auth_controller.js')); + controllers.push(require('./controllers/password_reset_controller.js')); + controllers.push(require('./controllers/comments_controller.js')); + controllers.push(require('./controllers/snapshots_controller.js')); + controllers.push(require('./controllers/post_detail_controller.js')); + controllers.push(require('./controllers/post_main_controller.js')); + controllers.push(require('./controllers/post_list_controller.js')); + controllers.push(require('./controllers/post_upload_controller.js')); + controllers.push(require('./controllers/tag_controller.js')); + controllers.push(require('./controllers/tag_list_controller.js')); + controllers.push(require('./controllers/tag_categories_controller.js')); + controllers.push(require('./controllers/settings_controller.js')); + controllers.push(require('./controllers/user_controller.js')); + controllers.push(require('./controllers/user_list_controller.js')); + controllers.push(require('./controllers/user_registration_controller.js')); + + // 404 controller needs to be registered last + controllers.push(require('./controllers/not_found_controller.js')); + + for (let controller of controllers) { + controller(router); + } +}, error => { + window.alert('Could not fetch basic configuration from server'); +}).then(() => { + api.loginFromCookies().then(() => { router.start(); - } else { - const ctx = router.start('/'); - ctx.controller.showError( - 'An error happened while trying to log you in: ' + - error.message); - } - }); + }, error => { + if (window.location.href.indexOf('login') !== -1) { + api.forget(); + router.start(); + } else { + const ctx = router.start('/'); + ctx.controller.showError( + 'An error happened while trying to log you in: ' + + error.message); + } + }); +}); diff --git a/client/js/models/post_list.js b/client/js/models/post_list.js index 2e3c6fca..2bfd056b 100644 --- a/client/js/models/post_list.js +++ b/client/js/models/post_list.js @@ -37,7 +37,7 @@ class PostList extends AbstractList { static _decorateSearchQuery(text) { const browsingSettings = settings.get(); const disabledSafety = []; - if (config.enableSafety) { + if (api.safetyEnabled()) { for (let key of Object.keys(browsingSettings.listPosts)) { if (browsingSettings.listPosts[key] === false) { disabledSafety.push(key); diff --git a/client/js/models/top_navigation.js b/client/js/models/top_navigation.js index ebb3753e..a7c726db 100644 --- a/client/js/models/top_navigation.js +++ b/client/js/models/top_navigation.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); class TopNavigationItem { constructor(accessKey, title, url, available, imageUrl) { @@ -53,8 +53,10 @@ class TopNavigation extends events.EventTarget { } setTitle(title) { - document.oldTitle = null; - document.title = config.name + (title ? (' – ' + title) : ''); + api.fetchConfig().then(() => { + document.oldTitle = null; + document.title = api.getName() + (title ? (' – ' + title) : ''); + }); } showAll() { diff --git a/client/js/views/help_view.js b/client/js/views/help_view.js index 4938b231..11f6a39a 100644 --- a/client/js/views/help_view.js +++ b/client/js/views/help_view.js @@ -1,6 +1,6 @@ 'use strict'; -const config = require('../config.js'); +const api = require('../api.js'); const views = require('../util/views.js'); const template = views.getTemplate('help'); @@ -26,7 +26,7 @@ class HelpView { const sourceNode = template(); const ctx = { - name: config.name, + name: api.getName(), }; section = section || 'about'; diff --git a/client/js/views/login_view.js b/client/js/views/login_view.js index 7d97982c..2c05332c 100644 --- a/client/js/views/login_view.js +++ b/client/js/views/login_view.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); const views = require('../util/views.js'); const template = views.getTemplate('login'); @@ -12,15 +12,15 @@ class LoginView extends events.EventTarget { this._hostNode = document.getElementById('content-holder'); views.replaceContent(this._hostNode, template({ - userNamePattern: config.userNameRegex, - passwordPattern: config.passwordRegex, - canSendMails: config.canSendMails, + userNamePattern: api.getUserNameRegex(), + passwordPattern: api.getPasswordRegex(), + canSendMails: api.canSendMails(), })); views.syncScrollPosition(); views.decorateValidator(this._formNode); - this._userNameInputNode.setAttribute('pattern', config.userNameRegex); - this._passwordInputNode.setAttribute('pattern', config.passwordRegex); + this._userNameInputNode.setAttribute('pattern', api.getUserNameRegex()); + this._passwordInputNode.setAttribute('pattern', api.getPasswordRegex()); this._formNode.addEventListener('submit', e => { e.preventDefault(); this.dispatchEvent(new CustomEvent('submit', { diff --git a/client/js/views/not_found_view.js b/client/js/views/not_found_view.js index 8b5a1039..487613b5 100644 --- a/client/js/views/not_found_view.js +++ b/client/js/views/not_found_view.js @@ -1,6 +1,5 @@ 'use strict'; -const config = require('../config.js'); const views = require('../util/views.js'); const template = views.getTemplate('not-found'); diff --git a/client/js/views/password_reset_view.js b/client/js/views/password_reset_view.js index 25409c77..685fe5a0 100644 --- a/client/js/views/password_reset_view.js +++ b/client/js/views/password_reset_view.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); const views = require('../util/views.js'); const template = views.getTemplate('password-reset'); @@ -12,8 +12,8 @@ class PasswordResetView extends events.EventTarget { this._hostNode = document.getElementById('content-holder'); views.replaceContent(this._hostNode, template({ - canSendMails: config.canSendMails, - contactEmail: config.contactEmail, + canSendMails: api.canSendMails(), + contactEmail: api.getContactEmail(), })); views.syncScrollPosition(); diff --git a/client/js/views/post_merge_view.js b/client/js/views/post_merge_view.js index 04ed21da..3e987b36 100644 --- a/client/js/views/post_merge_view.js +++ b/client/js/views/post_merge_view.js @@ -1,6 +1,5 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); const views = require('../util/views.js'); diff --git a/client/js/views/registration_view.js b/client/js/views/registration_view.js index fb924c2f..48034ddf 100644 --- a/client/js/views/registration_view.js +++ b/client/js/views/registration_view.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); const views = require('../util/views.js'); const template = views.getTemplate('user-registration'); @@ -11,8 +11,8 @@ class RegistrationView extends events.EventTarget { super(); this._hostNode = document.getElementById('content-holder'); views.replaceContent(this._hostNode, template({ - userNamePattern: config.userNameRegex, - passwordPattern: config.passwordRegex, + userNamePattern: api.getUserNameRegex(), + passwordPattern: api.getPasswordRegex(), })); views.syncScrollPosition(); views.decorateValidator(this._formNode); diff --git a/client/js/views/tag_edit_view.js b/client/js/views/tag_edit_view.js index 77c7cefc..5b517d46 100644 --- a/client/js/views/tag_edit_view.js +++ b/client/js/views/tag_edit_view.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); const misc = require('../util/misc.js'); const views = require('../util/views.js'); const TagInputControl = require('../controls/tag_input_control.js'); @@ -64,7 +64,7 @@ class TagEditView extends events.EventTarget { } _evtNameInput(e) { - const regex = new RegExp(config.tagNameRegex); + const regex = new RegExp(api.getTagNameRegex()); const list = misc.splitByWhitespace(this._namesFieldNode.value); if (!list.length) { diff --git a/client/js/views/tag_merge_view.js b/client/js/views/tag_merge_view.js index 12286800..c975a50e 100644 --- a/client/js/views/tag_merge_view.js +++ b/client/js/views/tag_merge_view.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); const views = require('../util/views.js'); const TagAutoCompleteControl = require('../controls/tag_auto_complete_control.js'); @@ -14,7 +14,7 @@ class TagMergeView extends events.EventTarget { this._tag = ctx.tag; this._hostNode = ctx.hostNode; - ctx.tagNamePattern = config.tagNameRegex; + ctx.tagNamePattern = api.getTagNameRegex(); views.replaceContent(this._hostNode, template(ctx)); views.decorateValidator(this._formNode); diff --git a/client/js/views/user_edit_view.js b/client/js/views/user_edit_view.js index 28208ead..8a6f1a4a 100644 --- a/client/js/views/user_edit_view.js +++ b/client/js/views/user_edit_view.js @@ -1,7 +1,7 @@ 'use strict'; -const config = require('../config.js'); const events = require('../events.js'); +const api = require('../api.js'); const views = require('../util/views.js'); const FileDropperControl = require('../controls/file_dropper_control.js'); @@ -11,8 +11,8 @@ class UserEditView extends events.EventTarget { constructor(ctx) { super(); - ctx.userNamePattern = config.userNameRegex + /|^$/.source; - ctx.passwordPattern = config.passwordRegex + /|^$/.source; + ctx.userNamePattern = api.getUserNameRegex() + /|^$/.source; + ctx.passwordPattern = api.getPasswordRegex() + /|^$/.source; this._user = ctx.user; this._hostNode = ctx.hostNode; diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index 19b42cb8..8f8c9ea4 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -35,11 +35,15 @@ def get_info( 'diskUsage': _get_disk_usage(), 'serverTime': datetime.utcnow(), 'config': { + 'name': config.config['name'], 'userNameRegex': config.config['user_name_regex'], 'passwordRegex': config.config['password_regex'], 'tagNameRegex': config.config['tag_name_regex'], 'tagCategoryNameRegex': config.config['tag_category_name_regex'], 'defaultUserRank': config.config['default_rank'], + 'enableSafety': config.config['enable_safety'], + 'contactEmail': config.config['contactEmail'], + 'canSendMails': bool(config.config['smtp']['host']), 'privileges': util.snake_case_to_lower_camel_case_keys( config.config['privileges']), From 90503566b53ff1255ac331f0343d530dfed5a6c7 Mon Sep 17 00:00:00 2001 From: nothink <33431+nothink@users.noreply.github.com> Date: Mon, 2 Jul 2018 11:35:08 +0900 Subject: [PATCH 118/159] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 3b5c4e44..75c7941c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.com/nothink/szurubooru.svg?branch=travis)](https://travis-ci.com/nothink/szurubooru) + # szurubooru Szurubooru is an image board engine inspired by services such as Danbooru, From 8d522c0b26f560a793a15e063c129df8969e7de5 Mon Sep 17 00:00:00 2001 From: nothink <33431+nothink@users.noreply.github.com> Date: Mon, 2 Jul 2018 14:09:16 +0900 Subject: [PATCH 119/159] Update README.md --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 75c7941c..3b5c4e44 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -[![Build Status](https://travis-ci.com/nothink/szurubooru.svg?branch=travis)](https://travis-ci.com/nothink/szurubooru) - # szurubooru Szurubooru is an image board engine inspired by services such as Danbooru, From 60ab9246c66cba422ccd79592173a0791817e45b Mon Sep 17 00:00:00 2001 From: Shyam Sunder Date: Thu, 5 Jul 2018 19:25:08 -0400 Subject: [PATCH 120/159] client: improved build.js, use relative links * Removed unnecessary require('config.js') calls * 'markdown.js' now uses rel. links in EntityPermalinkWrapper * 'password_reset.py' now generates rel. links * Removed 'Base URL' config parameter * Removed 'API URL' config parameter * 'build.js' no longer reads/requires config.yaml * Updated documentation * Removed unnecessary node packages used in 'build.js' abandon api_url parameter --- INSTALL.md | 30 +- client/build.js | 82 +- client/html/index.htm | 2 +- client/html/post_readonly_sidebar.tpl | 4 +- client/js/api.js | 3 +- .../controls/post_readonly_sidebar_control.js | 1 - client/js/models/post.js | 2 + client/js/models/post_list.js | 1 - client/js/util/markdown.js | 10 +- client/package-lock.json | 1140 ++++++++--------- client/package.json | 4 - config.yaml.dist | 5 +- server/szurubooru/api/password_reset_api.py | 3 +- server/szurubooru/facade.py | 2 +- 14 files changed, 595 insertions(+), 694 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index a88152ca..90313ddd 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -84,29 +84,32 @@ user@host:szuru/server$ source python_modules/bin/activate # enters the sandbox ### Preparing `szurubooru` for first run -1. Configure things: +1. Compile the frontend: ```console + user@host:szuru$ cd client + user@host:szuru/client$ node build.js + ``` + + You can include the flags `--no-transpile` to disable the JavaScript + transpiler, which provides compatibility with older browsers, and + `--debug` to generate JS source mappings. + +2. Configure things: + + ```console + user@host:szuru/client$ cd .. user@host:szuru$ cp config.yaml.dist config.yaml user@host:szuru$ vim config.yaml ``` Pay extra attention to these fields: - - base URL, - - API URL, - data directory, - data URL, - database, - the `smtp` section. -2. Compile the frontend: - - ```console - user@host:szuru$ cd client - user@host:szuru/client$ npm run build - ``` - 3. Upgrade the database: ```console @@ -140,6 +143,11 @@ meant to be exposed directly to the end users. The API should be exposed using WSGI server such as `waitress`, `gunicorn` or similar. Other configurations might be possible but I didn't pursue them. +API calls are made to the relative URL `/api/`. Your HTTP server should be +configured to proxy this URL format to the WSGI server. Some users may prefer +to use a dedicated reverse proxy for this, to incorporate additional features +such as load balancing and SSL. + Note that the API URL in the virtual host configuration needs to be the same as the one in the `config.yaml`, so that client knows how to access the backend! @@ -177,8 +185,6 @@ server { **`config.yaml`**: ```yaml -api_url: 'http://example.com/api/' -base_url: 'http://example.com/' data_url: 'http://example.com/data/' data_dir: '/srv/www/booru/client/public/data' ``` diff --git a/client/build.js b/client/build.js index c30f33f4..120daa24 100644 --- a/client/build.js +++ b/client/build.js @@ -5,20 +5,6 @@ const glob = require('glob'); const path = require('path'); const util = require('util'); const execSync = require('child_process').execSync; -const camelcase = require('camelcase'); - -function convertKeysToCamelCase(input) { - let result = {}; - Object.keys(input).map((key, _) => { - const value = input[key]; - if (value !== null && value.constructor == Object) { - result[camelcase(key)] = convertKeysToCamelCase(value); - } else { - result[camelcase(key)] = value; - } - }); - return result; -} function readTextFile(path) { return fs.readFileSync(path, 'utf-8'); @@ -29,37 +15,27 @@ function writeFile(path, content) { } function getVersion() { - return execSync('git describe --always --dirty --long --tags') - .toString() - .trim(); + let build_info = process.env.BUILD_INFO; + if (build_info) { + return build_info.trim(); + } else { + try { + build_info = execSync('git describe --always --dirty --long --tags') + .toString(); + } catch (e) { + console.warn('Cannot find build version'); + return 'unknown'; + } + return build_info.trim(); + } } function getConfig() { - const yaml = require('js-yaml'); - const merge = require('merge'); - const camelcaseKeys = require('camelcase-keys'); - - function parseConfigFile(path) { - let result = yaml.load(readTextFile(path, 'utf-8')); - return convertKeysToCamelCase(result); - } - - let config = parseConfigFile('../config.yaml.dist'); - - try { - const localConfig = parseConfigFile('../config.yaml'); - config = merge.recursive(config, localConfig); - } catch (e) { - console.warn('Local config does not exist, ignoring'); - } - - config.canSendMails = !!config.smtp.host; - delete config.secret; - delete config.smtp; - delete config.database; - config.meta = { - version: getVersion(), - buildDate: new Date().toUTCString(), + let config = { + meta: { + version: getVersion(), + buildDate: new Date().toUTCString() + } }; return config; @@ -85,15 +61,11 @@ function minifyHtml(html) { }).trim(); } -function bundleHtml(config) { +function bundleHtml() { const underscore = require('underscore'); const babelify = require('babelify'); const baseHtml = readTextFile('./html/index.htm', 'utf-8'); - const finalHtml = baseHtml - .replace( - /()(.*)(<\/title>)/, - util.format('$1%s$3', config.name)); - writeFile('./public/index.htm', minifyHtml(finalHtml)); + writeFile('./public/index.htm', minifyHtml(baseHtml)); glob('./html/**/*.tpl', {}, (er, files) => { let compiledTemplateJs = '\'use strict\'\n'; @@ -143,7 +115,7 @@ function bundleCss() { }); } -function bundleJs(config) { +function bundleJs() { const browserify = require('browserify'); const external = [ 'underscore', @@ -170,7 +142,7 @@ function bundleJs(config) { for (let lib of external) { b.require(lib); } - if (config.transpile) { + if (!process.argv.includes('--no-transpile')) { b.add(require.resolve('babel-polyfill')); } writeJsBundle( @@ -179,15 +151,15 @@ function bundleJs(config) { if (!process.argv.includes('--no-app-js')) { let outputFile = fs.createWriteStream('./public/js/app.min.js'); - let b = browserify({debug: config.debug}); - if (config.transpile) { + let b = browserify({debug: process.argv.includes('--debug')}); + if (!process.argv.includes('--no-transpile')) { b = b.transform('babelify'); } writeJsBundle( b.external(external).add(files), './public/js/app.min.js', 'Bundled app JS', - !config.debug); + !process.argv.includes('--debug')); } }); } @@ -217,11 +189,11 @@ const config = getConfig(); bundleConfig(config); bundleBinaryAssets(); if (!process.argv.includes('--no-html')) { - bundleHtml(config); + bundleHtml(); } if (!process.argv.includes('--no-css')) { bundleCss(); } if (!process.argv.includes('--no-js')) { - bundleJs(config); + bundleJs(); } diff --git a/client/html/index.htm b/client/html/index.htm index f20ad461..5df8e6dd 100644 --- a/client/html/index.htm +++ b/client/html/index.htm @@ -3,7 +3,7 @@ <head> <meta charset='utf-8'/> <meta name='viewport' content='width=device-width, initial-scale=1, maximum-scale=1'> - <title><!-- configured in the config file --> + Loading... diff --git a/client/html/post_readonly_sidebar.tpl b/client/html/post_readonly_sidebar.tpl index 51ba9b50..1209cf16 100644 --- a/client/html/post_readonly_sidebar.tpl +++ b/client/html/post_readonly_sidebar.tpl @@ -36,8 +36,8 @@