diff --git a/app/__init__.py b/app/__init__.py index e92cd70..a1b9e2b 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -9,6 +9,7 @@ from flask_mail import Mail from flask_bootstrap import Bootstrap from flask_moment import Moment from flask_babel import Babel, lazy_gettext as _l +from elasticsearch import Elasticsearch from config import Config db = SQLAlchemy() @@ -33,6 +34,8 @@ def create_app(config_class=Config): bootstrap.init_app(app) moment.init_app(app) babel.init_app(app) + app.elasticsearch = Elasticsearch([app.config['ELASTICSEARCH_URL']]) \ + if app.config['ELASTICSEARCH_URL'] else None from app.errors import bp as errors_bp app.register_blueprint(errors_bp) diff --git a/app/main/forms.py b/app/main/forms.py index a688ba7..cb587db 100644 --- a/app/main/forms.py +++ b/app/main/forms.py @@ -27,3 +27,13 @@ class PostForm(FlaskForm): post = TextAreaField(_l('Say something'), validators=[DataRequired()]) submit = SubmitField(_l('Submit')) + +class SearchForm(FlaskForm): + q = StringField(_l('Search'), validators=[DataRequired()]) + + def __init__(self, *args, **kwargs): + if 'formdata' not in kwargs: + kwargs['formdata'] = request.args + if 'csrf_enabled' not in kwargs: + kwargs['csrf_enabled'] = False + super(SearchForm, self).__init__(*args, **kwargs) diff --git a/app/main/routes.py b/app/main/routes.py index 8a31b20..ec4a07b 100644 --- a/app/main/routes.py +++ b/app/main/routes.py @@ -5,7 +5,7 @@ from flask_login import current_user, login_required from flask_babel import _, get_locale from guess_language import guess_language from app import db -from app.main.forms import EditProfileForm, PostForm +from app.main.forms import EditProfileForm, PostForm, SearchForm from app.models import User, Post from app.translate import translate from app.main import bp @@ -16,6 +16,7 @@ def before_request(): if current_user.is_authenticated: current_user.last_seen = datetime.utcnow() db.session.commit() + g.search_form = SearchForm() g.locale = str(get_locale()) @@ -132,3 +133,18 @@ def translate_text(): request.form['source_language'], request.form['dest_language'])}) + +@bp.route('/search') +@login_required +def search(): + if not g.search_form.validate(): + return redirect(url_for('main.explore')) + page = request.args.get('page', 1, type=int) + posts, total = Post.search(g.search_form.q.data, page, + current_app.config['POSTS_PER_PAGE']) + next_url = url_for('main.search', q=g.search_form.q.data, page=page + 1) \ + if total > page * current_app.config['POSTS_PER_PAGE'] else None + prev_url = url_for('main.search', q=g.search_form.q.data, page=page - 1) \ + if page > 1 else None + return render_template('search.html', title=_('Search'), posts=posts, + next_url=next_url, prev_url=prev_url) diff --git a/app/models.py b/app/models.py index 70ad4a1..ecd6b20 100644 --- a/app/models.py +++ b/app/models.py @@ -6,6 +6,50 @@ from flask_login import UserMixin from werkzeug.security import generate_password_hash, check_password_hash import jwt from app import db, login +from app.search import add_to_index, remove_from_index, query_index + + +class SearchableMixin(object): + @classmethod + def search(cls, expression, page, per_page): + ids, total = query_index(cls.__tablename__, expression, page, per_page) + if total == 0: + return cls.query.filter_by(id=0), 0 + when = [] + for i in range(len(ids)): + when.append((ids[i], i)) + return cls.query.filter(cls.id.in_(ids)).order_by( + db.case(when, value=cls.id)), total + + @classmethod + def before_commit(cls, session): + session._changes = { + 'add': list(session.new), + 'update': list(session.dirty), + 'delete': list(session.deleted) + } + + @classmethod + def after_commit(cls, session): + for obj in session._changes['add']: + if isinstance(obj, SearchableMixin): + add_to_index(obj.__tablename__, obj) + for obj in session._changes['update']: + if isinstance(obj, SearchableMixin): + add_to_index(obj.__tablename__, obj) + for obj in session._changes['delete']: + if isinstance(obj, SearchableMixin): + remove_from_index(obj.__tablename__, obj) + session._changes = None + + @classmethod + def reindex(cls): + for obj in cls.query: + add_to_index(cls.__tablename__, obj) + + +db.event.listen(db.session, 'before_commit', SearchableMixin.before_commit) +db.event.listen(db.session, 'after_commit', SearchableMixin.after_commit) followers = db.Table( @@ -83,7 +127,8 @@ def load_user(id): return User.query.get(int(id)) -class Post(db.Model): +class Post(SearchableMixin, db.Model): + __searchable__ = ['body'] id = db.Column(db.Integer, primary_key=True) body = db.Column(db.String(140)) timestamp = db.Column(db.DateTime, index=True, default=datetime.utcnow) diff --git a/app/search.py b/app/search.py new file mode 100644 index 0000000..3b939fd --- /dev/null +++ b/app/search.py @@ -0,0 +1,28 @@ +from flask import current_app + + +def add_to_index(index, model): + if not current_app.elasticsearch: + return + payload = {} + for field in model.__searchable__: + payload[field] = getattr(model, field) + current_app.elasticsearch.index(index=index, doc_type=index, id=model.id, + body=payload) + + +def remove_from_index(index, model): + if not current_app.elasticsearch: + return + current_app.elasticsearch.delete(index=index, doc_type=index, id=model.id) + + +def query_index(index, query, page, per_page): + if not current_app.elasticsearch: + return [], 0 + search = current_app.elasticsearch.search( + index=index, doc_type=index, + body={'query': {'multi_match': {'query': query, 'fields': ['*']}}, + 'from': (page - 1) * per_page, 'size': per_page}) + ids = [int(hit['_id']) for hit in search['hits']['hits']] + return ids, search['hits']['total'] diff --git a/app/templates/base.html b/app/templates/base.html index 6a54732..a985a70 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -21,6 +21,13 @@
  • {{ _('Home') }}
  • {{ _('Explore') }}
  • + {% if g.search_form %} + + {% endif %}