from datetime import datetime, timezone
from hashlib import md5
import json
from time import time
from typing import Optional

import sqlalchemy as sa
import sqlalchemy.orm as so
from flask import current_app
from flask_login import UserMixin
from werkzeug.security import generate_password_hash, check_password_hash
import jwt
import redis
import rq

from app import db, login
from app.search import add_to_index, remove_from_index, query_index


class SearchableMixin:
    """Mixin that mirrors a mapped model into the full-text search index.

    Subclasses must be mapped models with an ``id`` column; the index name
    used for every operation is the model's ``__tablename__``.
    """

    @classmethod
    def search(cls, expression, page, per_page):
        """Search the index and load the matching rows from the database.

        Returns a ``(results, total)`` pair: *results* yields the model
        instances whose ids ``query_index`` returned, in the same order the
        index returned them, and *total* is the hit count it reported.
        """
        ids, total = query_index(cls.__tablename__, expression, page, per_page)
        if total == 0:
            return [], 0
        if not ids:
            # The index reported hits but none on this page (e.g. a page
            # number past the end). Building CASE with no WHEN clauses
            # would emit invalid SQL, so skip the database round trip.
            return [], total
        # Map each id to its position so ORDER BY reproduces the ordering
        # the search index returned instead of the database's own order.
        when = [(id_, position) for position, id_ in enumerate(ids)]
        query = sa.select(cls).where(cls.id.in_(ids)).order_by(
            db.case(*when, value=cls.id))
        return db.session.scalars(query), total

    @classmethod
    def before_commit(cls, session):
        """Snapshot pending adds/updates/deletes before the commit runs.

        The session's ``new``/``dirty``/``deleted`` collections are consumed
        by the commit, so they are copied onto the session for
        :meth:`after_commit` to pick up.
        """
        session._changes = {
            'add': list(session.new),
            'update': list(session.dirty),
            'delete': list(session.deleted)
        }

    @classmethod
    def after_commit(cls, session):
        """Push the snapshot taken in :meth:`before_commit` to the index.

        Only objects that are ``SearchableMixin`` instances are indexed;
        added and updated objects are (re)indexed, deleted ones removed.
        """
        for obj in session._changes['add']:
            if isinstance(obj, SearchableMixin):
                add_to_index(obj.__tablename__, obj)
        for obj in session._changes['update']:
            if isinstance(obj, SearchableMixin):
                add_to_index(obj.__tablename__, obj)
        for obj in session._changes['delete']:
            if isinstance(obj, SearchableMixin):
                remove_from_index(obj.__tablename__, obj)
        session._changes = None

    @classmethod
    def reindex(cls):
        """Add every row of this model to the search index."""
        for obj in db.session.scalars(sa.select(cls)):
            add_to_index(cls.__tablename__, obj)


# Keep the search index in sync with every commit made through db.session.
db.event.listen(db.session, 'before_commit', SearchableMixin.before_commit)
db.event.listen(db.session, 'after_commit', SearchableMixin.after_commit)

# Association table for the self-referential follower/followed relationship.
followers = sa.Table(
    'followers',
    db.metadata,
    sa.Column('follower_id', sa.Integer, sa.ForeignKey('user.id'),
              primary_key=True),
    sa.Column('followed_id', sa.Integer, sa.ForeignKey('user.id'),
              primary_key=True)
)


class User(UserMixin, db.Model):
    """An application user, with social graph, messaging and task state."""

    id: so.Mapped[int] = so.mapped_column(primary_key=True)
    username: so.Mapped[str] = so.mapped_column(sa.String(64), index=True,
                                                unique=True)
    email: so.Mapped[str] = so.mapped_column(sa.String(120), index=True,
                                             unique=True)
    password_hash: so.Mapped[Optional[str]] = so.mapped_column(sa.String(256))
    about_me: so.Mapped[Optional[str]] = so.mapped_column(sa.String(140))
    last_seen: so.Mapped[Optional[datetime]] = so.mapped_column(
        default=lambda: datetime.now(timezone.utc))
    # None means the user has never opened their messages.
    last_message_read_time: so.Mapped[Optional[datetime]]

    posts: so.WriteOnlyMapped['Post'] = so.relationship(
        back_populates='author')
    # NOTE: `id` below is this class's column and the first `followers`
    # is the module-level association table (the relationship attribute of
    # the same name is not bound yet when its right-hand side is evaluated),
    # so these two definitions must stay in this order.
    following: so.WriteOnlyMapped['User'] = so.relationship(
        secondary=followers, primaryjoin=(followers.c.follower_id == id),
        secondaryjoin=(followers.c.followed_id == id),
        back_populates='followers')
    followers: so.WriteOnlyMapped['User'] = so.relationship(
        secondary=followers, primaryjoin=(followers.c.followed_id == id),
        secondaryjoin=(followers.c.follower_id == id),
        back_populates='following')
    messages_sent: so.WriteOnlyMapped['Message'] = so.relationship(
        foreign_keys='Message.sender_id', back_populates='author')
    messages_received: so.WriteOnlyMapped['Message'] = so.relationship(
        foreign_keys='Message.recipient_id', back_populates='recipient')
    notifications: so.WriteOnlyMapped['Notification'] = so.relationship(
        back_populates='user')
    tasks: so.WriteOnlyMapped['Task'] = so.relationship(back_populates='user')

    def __repr__(self):
        return f'<User {self.username}>'

    def set_password(self, password):
        """Store a salted hash of *password* (the plain text is not kept)."""
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True if *password* matches the stored hash."""
        return check_password_hash(self.password_hash, password)

    def avatar(self, size):
        """Return the Gravatar URL for this user's email at *size* pixels."""
        # Gravatar keys avatars by the MD5 of the lower-cased email; this is
        # an identifier, not a security use of MD5.
        digest = md5(self.email.lower().encode('utf-8')).hexdigest()
        return f'https://www.gravatar.com/avatar/{digest}?d=identicon&s={size}'

    def follow(self, user):
        """Start following *user*; no-op if already following."""
        if not self.is_following(user):
            self.following.add(user)

    def unfollow(self, user):
        """Stop following *user*; no-op if not following."""
        if self.is_following(user):
            self.following.remove(user)

    def is_following(self, user):
        """Return True if this user follows *user* (issues a query)."""
        query = self.following.select().where(User.id == user.id)
        return db.session.scalar(query) is not None

    def followers_count(self):
        """Return the number of users following this user."""
        query = sa.select(sa.func.count()).select_from(
            self.followers.select().subquery())
        return db.session.scalar(query)

    def following_count(self):
        """Return the number of users this user follows."""
        query = sa.select(sa.func.count()).select_from(
            self.following.select().subquery())
        return db.session.scalar(query)

    def following_posts(self):
        """Build (but do not execute) the timeline query for this user.

        The returned ``Select`` yields posts written by this user or by any
        user they follow, newest first.
        """
        Author = so.aliased(User)
        Follower = so.aliased(User)
        return (
            sa.select(Post)
            .join(Post.author.of_type(Author))
            # Outer join so the user's own posts survive even when the
            # author has no followers at all.
            .join(Author.followers.of_type(Follower), isouter=True)
            .where(sa.or_(
                Follower.id == self.id,
                Author.id == self.id,
            ))
            # The follower join can repeat a post; grouping collapses them.
            .group_by(Post)
            .order_by(Post.timestamp.desc())
        )

    def get_reset_password_token(self, expires_in=600):
        """Return an HS256 JWT identifying this user for a password reset.

        The token is signed with the app's ``SECRET_KEY`` and carries an
        ``exp`` claim *expires_in* seconds in the future.
        """
        return jwt.encode(
            {'reset_password': self.id, 'exp': time() + expires_in},
            current_app.config['SECRET_KEY'], algorithm='HS256')

    @staticmethod
    def verify_reset_password_token(token):
        """Return the user a reset token was issued for, or None.

        None is returned when the token cannot be decoded, fails signature
        or expiry checks, or lacks the ``reset_password`` claim.
        """
        secret = current_app.config['SECRET_KEY']
        try:
            user_id = jwt.decode(token, secret,
                                 algorithms=['HS256'])['reset_password']
        except (jwt.InvalidTokenError, KeyError):
            return None
        return db.session.get(User, user_id)

    def unread_message_count(self):
        """Count received messages newer than the last time they were read."""
        # With no recorded read time, use a date before any message so that
        # every received message counts as unread.
        last_read_time = self.last_message_read_time or datetime(1900, 1, 1)
        query = sa.select(Message).where(Message.recipient == self,
                                         Message.timestamp > last_read_time)
        return db.session.scalar(sa.select(sa.func.count()).select_from(
            query.subquery()))

    def add_notification(self, name, data):
        """Replace this user's notification called *name* with new *data*.

        Existing notifications with that name are deleted and a new one whose
        payload is ``json.dumps(data)`` is added to the session. Nothing is
        committed here; the new ``Notification`` is returned.
        """
        db.session.execute(self.notifications.delete().where(
            Notification.name == name))
        n = Notification(name=name, payload_json=json.dumps(data), user=self)
        db.session.add(n)
        return n

    def launch_task(self, name, description, *args, **kwargs):
        """Enqueue ``app.tasks.<name>`` and record it as a Task row.

        The job receives this user's id followed by *args*/*kwargs*. The
        Task (keyed by the queue job id) is added to the session but not
        committed.
        """
        rq_job = current_app.task_queue.enqueue(f'app.tasks.{name}', self.id,
                                                *args, **kwargs)
        task = Task(id=rq_job.get_id(), name=name, description=description,
                    user=self)
        db.session.add(task)
        return task

    def get_tasks_in_progress(self):
        """Return this user's tasks not yet marked complete."""
        # `== False` builds a SQL comparison here, not a Python truth test.
        query = self.tasks.select().where(Task.complete == False)  # noqa: E712
        return db.session.scalars(query)

    def get_task_in_progress(self, name):
        """Return an incomplete task of this user named *name*, or None."""
        query = self.tasks.select().where(Task.name == name,
                                          Task.complete == False)  # noqa: E712
        return db.session.scalar(query)


@login.user_loader
def load_user(id):
    """Flask-Login callback: load a user from the id stored in the session."""
    return db.session.get(User, int(id))


class Post(SearchableMixin, db.Model):
    """A short post written by a user; its body is full-text indexed."""

    __searchable__ = ['body']
    id: so.Mapped[int] = so.mapped_column(primary_key=True)
    body: so.Mapped[str] = so.mapped_column(sa.String(140))
    timestamp: so.Mapped[datetime] = so.mapped_column(
        index=True, default=lambda: datetime.now(timezone.utc))
    user_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey(User.id),
                                               index=True)
    language: so.Mapped[Optional[str]] = so.mapped_column(sa.String(5))

    author: so.Mapped[User] = so.relationship(back_populates='posts')

    def __repr__(self):
        return f'<Post {self.body}>'


class Message(db.Model):
    """A private message sent from one user to another."""

    id: so.Mapped[int] = so.mapped_column(primary_key=True)
    sender_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey(User.id),
                                                 index=True)
    recipient_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey(User.id),
                                                    index=True)
    body: so.Mapped[str] = so.mapped_column(sa.String(140))
    timestamp: so.Mapped[datetime] = so.mapped_column(
        index=True, default=lambda: datetime.now(timezone.utc))

    author: so.Mapped[User] = so.relationship(
        foreign_keys='Message.sender_id',
        back_populates='messages_sent')
    recipient: so.Mapped[User] = so.relationship(
        foreign_keys='Message.recipient_id',
        back_populates='messages_received')

    def __repr__(self):
        return f'<Message {self.body}>'


class Notification(db.Model):
    """A named notification for a user with a JSON-encoded payload."""

    id: so.Mapped[int] = so.mapped_column(primary_key=True)
    name: so.Mapped[str] = so.mapped_column(sa.String(128), index=True)
    user_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey(User.id),
                                               index=True)
    # Unix epoch seconds from time.time(), not a datetime.
    timestamp: so.Mapped[float] = so.mapped_column(index=True, default=time)
    payload_json: so.Mapped[str] = so.mapped_column(sa.Text)

    user: so.Mapped[User] = so.relationship(back_populates='notifications')

    def get_data(self):
        """Return the decoded JSON payload."""
        return json.loads(str(self.payload_json))


class Task(db.Model):
    """A background job launched for a user; ``id`` is the RQ job id."""

    id: so.Mapped[str] = so.mapped_column(sa.String(36), primary_key=True)
    name: so.Mapped[str] = so.mapped_column(sa.String(128), index=True)
    description: so.Mapped[Optional[str]] = so.mapped_column(sa.String(128))
    user_id: so.Mapped[int] = so.mapped_column(sa.ForeignKey(User.id))
    complete: so.Mapped[bool] = so.mapped_column(default=False)

    user: so.Mapped[User] = so.relationship(back_populates='tasks')

    def get_rq_job(self):
        """Fetch the RQ job for this task, or None if it cannot be fetched.

        Redis errors and unknown job ids both yield None.
        """
        try:
            rq_job = rq.job.Job.fetch(self.id, connection=current_app.redis)
        except (redis.exceptions.RedisError, rq.exceptions.NoSuchJobError):
            return None
        return rq_job

    def get_progress(self):
        """Return the job's ``progress`` meta value (default 0).

        Returns 100 when the job cannot be fetched from Redis.
        """
        job = self.get_rq_job()
        return job.meta.get('progress', 0) if job is not None else 100