diff --git a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/database.py b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/database.py index c1ffd95..98dd8d6 100644 --- a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/database.py +++ b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/database.py @@ -3,6 +3,38 @@ mixins. ''' -from flask.ext.sqlalchemy import SQLAlchemy +from .extensions import db -db = SQLAlchemy() +class CRUDMixin(object): + __table_args__ = {'extend_existing': True} + + id = db.Column(db.Integer, primary_key=True) + + @classmethod + def get_by_id(cls, id): + if any( + (isinstance(id, basestring) and id.isdigit(), + isinstance(id, (int, float))), + ): + return cls.query.get(int(id)) + return None + + @classmethod + def create(cls, **kwargs): + instance = cls(**kwargs) + return instance.save() + + def update(self, commit=True, **kwargs): + for attr, value in kwargs.iteritems(): + setattr(self, attr, value) + return commit and self.save() or self + + def save(self, commit=True): + db.session.add(self) + if commit: + db.session.commit() + return self + + def delete(self, commit=True): + db.session.delete(self) + return commit and db.session.commit() diff --git a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/extensions.py b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/extensions.py index 73c41bd..ec381fd 100644 --- a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/extensions.py +++ b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/extensions.py @@ -6,3 +6,6 @@ bcrypt = Bcrypt() from flask.ext.login import LoginManager login_manager = LoginManager() + +from flask.ext.sqlalchemy import SQLAlchemy +db = SQLAlchemy() diff --git a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/public/views.py b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/public/views.py index 875920e..ea55b96 100644 --- a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/public/views.py +++ b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/public/views.py @@ -3,7 +3,6 @@ from flask import (Blueprint, request, render_template, flash, url_for, redirect, session) from flask.ext.login import login_user, login_required, logout_user -from sqlalchemy.exc import IntegrityError from {{cookiecutter.repo_name}}.extensions import login_manager from {{cookiecutter.repo_name}}.user.models import User @@ -16,10 +15,7 @@ blueprint = Blueprint('public', __name__, static_folder="../static") @login_manager.user_loader def load_user(id): - try: - return User.query.get(int(id)) - except Exception: - return None + return User.get_by_id(int(id)) @blueprint.route("/", methods=["GET", "POST"]) diff --git a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/templates/_layouts/nav.html b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/templates/_layouts/nav.html index fb797a9..9e37955 100644 --- a/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/templates/_layouts/nav.html +++ b/{{cookiecutter.repo_name}}/{{cookiecutter.repo_name}}/templates/_layouts/nav.html @@ -17,7 +17,7 @@