diff --git a/app/__init__.py b/app/__init__.py index a94f301..49fad10 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -8,6 +8,7 @@ from flask_login import LoginManager from flask_mail import Mail from flask_moment import Moment from flask_babel import Babel, lazy_gettext as _l +from elasticsearch import Elasticsearch from config import Config @@ -35,6 +36,8 @@ def create_app(config_class=Config): mail.init_app(app) moment.init_app(app) babel.init_app(app, locale_selector=get_locale) + app.elasticsearch = Elasticsearch([app.config['ELASTICSEARCH_URL']]) \ + if app.config['ELASTICSEARCH_URL'] else None from app.errors import bp as errors_bp app.register_blueprint(errors_bp) diff --git a/app/main/forms.py b/app/main/forms.py index 8dec298..daabe7e 100644 --- a/app/main/forms.py +++ b/app/main/forms.py @@ -31,3 +31,13 @@ class PostForm(FlaskForm): post = TextAreaField(_l('Say something'), validators=[DataRequired()]) submit = SubmitField(_l('Submit')) + +class SearchForm(FlaskForm): + q = StringField(_l('Search'), validators=[DataRequired()]) + + def __init__(self, *args, **kwargs): + if 'formdata' not in kwargs: + kwargs['formdata'] = request.args + if 'meta' not in kwargs: + kwargs['meta'] = {'csrf': False} + super(SearchForm, self).__init__(*args, **kwargs) diff --git a/app/main/routes.py b/app/main/routes.py index bccd549..ee09024 100644 --- a/app/main/routes.py +++ b/app/main/routes.py @@ -5,7 +5,7 @@ from flask_login import current_user, login_required from flask_babel import _, get_locale from langdetect import detect, LangDetectException from app import db -from app.main.forms import EditProfileForm, EmptyForm, PostForm +from app.main.forms import EditProfileForm, EmptyForm, PostForm, SearchForm from app.models import User, Post from app.translate import translate from app.main import bp @@ -16,6 +16,7 @@ def before_request(): if current_user.is_authenticated: current_user.last_seen = datetime.utcnow() db.session.commit() + g.search_form = SearchForm() g.locale = str(get_locale()) @@ -145,3 +146,19 @@ def translate_text(): return {'text': translate(data['text'], data['source_language'], data['dest_language'])} + + +@bp.route('/search') +@login_required +def search(): + if not g.search_form.validate(): + return redirect(url_for('main.explore')) + page = request.args.get('page', 1, type=int) + posts, total = Post.search(g.search_form.q.data, page, + current_app.config['POSTS_PER_PAGE']) + next_url = url_for('main.search', q=g.search_form.q.data, page=page + 1) \ + if total > page * current_app.config['POSTS_PER_PAGE'] else None + prev_url = url_for('main.search', q=g.search_form.q.data, page=page - 1) \ + if page > 1 else None + return render_template('search.html', title=_('Search'), posts=posts, + next_url=next_url, prev_url=prev_url) diff --git a/app/models.py b/app/models.py index a8a9edc..762d72d 100644 --- a/app/models.py +++ b/app/models.py @@ -6,6 +6,50 @@ from flask_login import UserMixin from werkzeug.security import generate_password_hash, check_password_hash import jwt from app import db, login +from app.search import add_to_index, remove_from_index, query_index + + +class SearchableMixin(object): + @classmethod + def search(cls, expression, page, per_page): + ids, total = query_index(cls.__tablename__, expression, page, per_page) + if total == 0: + return cls.query.filter_by(id=0), 0 + when = [] + for i in range(len(ids)): + when.append((ids[i], i)) + return cls.query.filter(cls.id.in_(ids)).order_by( + db.case(*when, value=cls.id)), total + + @classmethod + def before_commit(cls, session): + session._changes = { + 'add': list(session.new), + 'update': list(session.dirty), + 'delete': list(session.deleted) + } + + @classmethod + def after_commit(cls, session): + for obj in session._changes['add']: + if isinstance(obj, SearchableMixin): + add_to_index(obj.__tablename__, obj) + for obj in session._changes['update']: + if isinstance(obj, SearchableMixin): + add_to_index(obj.__tablename__, obj) + for obj in session._changes['delete']: + if isinstance(obj, SearchableMixin): + remove_from_index(obj.__tablename__, obj) + session._changes = None + + @classmethod + def reindex(cls): + for obj in cls.query: + add_to_index(cls.__tablename__, obj) + + +db.event.listen(db.session, 'before_commit', SearchableMixin.before_commit) +db.event.listen(db.session, 'after_commit', SearchableMixin.after_commit) followers = db.Table( @@ -87,7 +131,8 @@ def load_user(id): return User.query.get(int(id)) -class Post(db.Model): +class Post(SearchableMixin, db.Model): + __searchable__ = ['body'] id = db.Column(db.Integer, primary_key=True) body = db.Column(db.String(140)) timestamp = db.Column(db.DateTime, index=True, default=datetime.utcnow) diff --git a/app/search.py b/app/search.py new file mode 100644 index 0000000..51c5c96 --- /dev/null +++ b/app/search.py @@ -0,0 +1,28 @@ +from flask import current_app + + +def add_to_index(index, model): + if not current_app.elasticsearch: + return + payload = {} + for field in model.__searchable__: + payload[field] = getattr(model, field) + current_app.elasticsearch.index(index=index, id=model.id, document=payload) + + +def remove_from_index(index, model): + if not current_app.elasticsearch: + return + current_app.elasticsearch.delete(index=index, id=model.id) + + +def query_index(index, query, page, per_page): + if not current_app.elasticsearch: + return [], 0 + search = current_app.elasticsearch.search( + index=index, + query={'multi_match': {'query': query, 'fields': ['*']}}, + from_=(page - 1) * per_page, + size=per_page) + ids = [int(hit['_id']) for hit in search['hits']['hits']] + return ids, search['hits']['total']['value'] diff --git a/app/templates/base.html b/app/templates/base.html index cb6e18a..5a66063 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -25,6 +25,13 @@ + {% if g.search_form %} + + {% endif %}