followers

This commit is contained in:
Miguel Grinberg 2012-12-16 00:34:24 -08:00
parent 8c091d7add
commit 47e65873a9
5 changed files with 171 additions and 3 deletions

View File

@ -4,6 +4,11 @@ from app import db
ROLE_USER = 0 ROLE_USER = 0
ROLE_ADMIN = 1 ROLE_ADMIN = 1
followers = db.Table('followers',
db.Column('follower_id', db.Integer, db.ForeignKey('user.id')),
db.Column('followed_id', db.Integer, db.ForeignKey('user.id'))
)
class User(db.Model): class User(db.Model):
id = db.Column(db.Integer, primary_key = True) id = db.Column(db.Integer, primary_key = True)
nickname = db.Column(db.String(64), unique = True) nickname = db.Column(db.String(64), unique = True)
@ -12,6 +17,12 @@ class User(db.Model):
posts = db.relationship('Post', backref = 'author', lazy = 'dynamic') posts = db.relationship('Post', backref = 'author', lazy = 'dynamic')
about_me = db.Column(db.String(140)) about_me = db.Column(db.String(140))
last_seen = db.Column(db.DateTime) last_seen = db.Column(db.DateTime)
followed = db.relationship('User',
secondary = followers,
primaryjoin = (followers.c.follower_id == id),
secondaryjoin = (followers.c.followed_id == id),
backref = db.backref('followers', lazy = 'dynamic'),
lazy = 'dynamic')
@staticmethod @staticmethod
def make_unique_nickname(nickname): def make_unique_nickname(nickname):
@ -40,6 +51,22 @@ class User(db.Model):
def avatar(self, size): def avatar(self, size):
return 'http://www.gravatar.com/avatar/' + md5(self.email).hexdigest() + '?d=mm&s=' + str(size) return 'http://www.gravatar.com/avatar/' + md5(self.email).hexdigest() + '?d=mm&s=' + str(size)
def follow(self, user):
if not self.is_following(user):
self.followed.append(user)
return self
def unfollow(self, user):
if self.is_following(user):
self.followed.remove(user)
return self
def is_following(self, user):
return self.followed.filter(followers.c.followed_id == user.id).count() > 0
def followed_posts(self):
return Post.query.join(followers, (followers.c.followed_id == Post.user_id)).filter(followers.c.follower_id == self.id).order_by(Post.timestamp.desc())
def __repr__(self): def __repr__(self):
return '<User %r>' % (self.nickname) return '<User %r>' % (self.nickname)

View File

@ -9,7 +9,15 @@
<h1>User: {{user.nickname}}</h1> <h1>User: {{user.nickname}}</h1>
{% if user.about_me %}<p>{{user.about_me}}</p>{% endif %} {% if user.about_me %}<p>{{user.about_me}}</p>{% endif %}
{% if user.last_seen %}<p><i>Last seen on: {{user.last_seen}}</i></p>{% endif %} {% if user.last_seen %}<p><i>Last seen on: {{user.last_seen}}</i></p>{% endif %}
{% if user.id == g.user.id %}<p><a href="{{url_for('edit')}}">Edit</a></p>{% endif %} <p>{{user.followers.count()}} followers |
{% if user.id == g.user.id %}
<a href="{{url_for('edit')}}">Edit your profile</a>
{% elif not g.user.is_following(user) %}
<a href="{{url_for('follow', nickname = user.nickname)}}">Follow</a>
{% else %}
<a href="{{url_for('unfollow', nickname = user.nickname)}}">Unfollow</a>
{% endif %}
</p>
</td> </td>
</tr> </tr>
</table> </table>

View File

@ -73,6 +73,9 @@ def after_login(resp):
user = User(nickname = nickname, email = resp.email, role = ROLE_USER) user = User(nickname = nickname, email = resp.email, role = ROLE_USER)
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
# make the user follow him/herself
db.session.add(user.follow(user))
db.session.commit()
remember_me = False remember_me = False
if 'remember_me' in session: if 'remember_me' in session:
remember_me = session['remember_me'] remember_me = session['remember_me']
@ -116,4 +119,33 @@ def edit():
form.about_me.data = g.user.about_me form.about_me.data = g.user.about_me
return render_template('edit.html', return render_template('edit.html',
form = form) form = form)
@app.route('/follow/<nickname>')
def follow(nickname):
user = User.query.filter_by(nickname = nickname).first()
if user == None:
flash('User ' + nickname + ' not found.')
return redirect(url_for('index'))
u = g.user.follow(user)
if u is None:
flash('Cannot follow ' + nickname + '.')
return redirect(url_for('user', nickname = nickname))
db.session.add(u)
db.session.commit()
flash('You are now following ' + nickname + '!')
return redirect(url_for('user', nickname = nickname))
@app.route('/unfollow/<nickname>')
def unfollow(nickname):
user = User.query.filter_by(nickname = nickname).first()
if user == None:
flash('User ' + nickname + ' not found.')
return redirect(url_for('index'))
u = g.user.unfollow(user)
if u is None:
flash('Cannot unfollow ' + nickname + '.')
return redirect(url_for('user', nickname = nickname))
db.session.add(u)
db.session.commit()
flash('You have stopped following ' + nickname + '.')
return redirect(url_for('user', nickname = nickname))

View File

@ -0,0 +1,26 @@
from sqlalchemy import *
from migrate import *
from migrate.changeset import schema
pre_meta = MetaData()
post_meta = MetaData()
followers = Table('followers', post_meta,
Column('follower_id', Integer),
Column('followed_id', Integer),
)
def upgrade(migrate_engine):
# Upgrade operations go here. Don't create your own engine; bind
# migrate_engine to your metadata
pre_meta.bind = migrate_engine
post_meta.bind = migrate_engine
post_meta.tables['followers'].create()
def downgrade(migrate_engine):
# Operations to reverse the above upgrade go here.
pre_meta.bind = migrate_engine
post_meta.bind = migrate_engine
post_meta.tables['followers'].drop()

View File

@ -1,10 +1,11 @@
#!flask/bin/python #!flask/bin/python
import os import os
import unittest import unittest
from datetime import datetime, timedelta
from config import basedir from config import basedir
from app import app, db from app import app, db
from app.models import User from app.models import User, Post
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
def setUp(self): def setUp(self):
@ -14,6 +15,7 @@ class TestCase(unittest.TestCase):
db.create_all() db.create_all()
def tearDown(self): def tearDown(self):
db.session.remove()
db.drop_all() db.drop_all()
def test_avatar(self): def test_avatar(self):
@ -38,5 +40,78 @@ class TestCase(unittest.TestCase):
assert nickname2 != 'john' assert nickname2 != 'john'
assert nickname2 != nickname assert nickname2 != nickname
def test_follow(self):
u1 = User(nickname = 'john', email = 'john@example.com')
u2 = User(nickname = 'susan', email = 'susan@example.com')
db.session.add(u1)
db.session.add(u2)
db.session.commit()
assert u1.unfollow(u2) == None
u = u1.follow(u2)
db.session.add(u)
db.session.commit()
assert u1.follow(u2) == None
assert u1.is_following(u2)
assert u1.followed.count() == 1
assert u1.followed.first().nickname == 'susan'
assert u2.followers.count() == 1
assert u2.followers.first().nickname == 'john'
u = u1.unfollow(u2)
assert u != None
db.session.add(u)
db.session.commit()
assert u1.is_following(u2) == False
assert u1.followed.count() == 0
assert u2.followers.count() == 0
def test_follow_posts(self):
# make four users
u1 = User(nickname = 'john', email = 'john@example.com')
u2 = User(nickname = 'susan', email = 'susan@example.com')
u3 = User(nickname = 'mary', email = 'mary@example.com')
u4 = User(nickname = 'david', email = 'david@example.com')
db.session.add(u1)
db.session.add(u2)
db.session.add(u3)
db.session.add(u4)
# make four posts
utcnow = datetime.utcnow()
p1 = Post(body = "post from john", author = u1, timestamp = utcnow + timedelta(seconds = 1))
p2 = Post(body = "post from susan", author = u2, timestamp = utcnow + timedelta(seconds = 2))
p3 = Post(body = "post from mary", author = u3, timestamp = utcnow + timedelta(seconds = 3))
p4 = Post(body = "post from david", author = u4, timestamp = utcnow + timedelta(seconds = 4))
db.session.add(p1)
db.session.add(p2)
db.session.add(p3)
db.session.add(p4)
db.session.commit()
# setup the followers
u1.follow(u1) # john follows himself
u1.follow(u2) # john follows susan
u1.follow(u4) # john follows david
u2.follow(u2) # susan follows herself
u2.follow(u3) # susan follows mary
u3.follow(u3) # mary follows herself
u3.follow(u4) # mary follows david
u4.follow(u4) # david follows himself
db.session.add(u1)
db.session.add(u2)
db.session.add(u3)
db.session.add(u4)
db.session.commit()
# check the followed posts of each user
f1 = u1.followed_posts().all()
f2 = u2.followed_posts().all()
f3 = u3.followed_posts().all()
f4 = u4.followed_posts().all()
assert len(f1) == 3
assert len(f2) == 2
assert len(f3) == 2
assert len(f4) == 1
assert f1 == [p4, p2, p1]
assert f2 == [p3, p2]
assert f3 == [p4, p3]
assert f4 == [p4]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()