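Flask is a lightweight web application framework written in Python, often described as a "microframework". Its design is simple and its core deliberately small, yet extensions make it suitable for complex web applications.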
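Feature | Flask | Django |
---|---|---|
Design philosophy | Microframework with a minimal core | Full-stack framework with a complete feature set |
Learning curve | Gentle | Comparatively steep |
Flexibility | Highly flexible | Convention over configuration |
Built-in features | Basics only; extensions supply the rest | Rich feature set out of the box |
Typical use | Small to medium projects, APIs | Large, complex projects |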
# 1. Create the project directory
mkdir flask-learning
cd flask-learning
# 2. Create a virtual environment
python -m venv venv
# 3. Activate the virtual environment
# Windows
venv\Scripts\activate
# macOS/Linux
source venv/bin/activate
# 4. Install Flask
pip install Flask
# 5. Verify the installation
python -c "import flask; print(flask.__version__)"
# app.py
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, Flask World!'

if __name__ == '__main__':
    # debug=True enables the reloader and debugger; use it only in development
    app.run(debug=True)

# Run the application
python app.py
# Then open http://localhost:5000
Flask's routing system is the core of a web application: it determines how URLs map to Python functions.
Basic routing examples:
from flask import Flask, request, url_for, redirect

app = Flask(__name__)

# 1. Basic route
@app.route('/')
def index():
    return 'Welcome to Flask!'

# 2. Route with a variable
@app.route('/user/<username>')
def show_user(username):
    return f'Hello, {username}!'

# 3. Typed variables via converters
@app.route('/post/<int:post_id>')
def show_post(post_id):
    return f'Post ID: {post_id} (type: {type(post_id).__name__})'

@app.route('/price/<float:price>')
def show_price(price):
    return f'Price: ${price:.2f}'

# 4. Multiple HTTP methods
@app.route('/api/data', methods=['GET', 'POST', 'PUT', 'DELETE'])
def handle_data():
    if request.method == 'GET':
        return {'message': 'GET request received'}
    elif request.method == 'POST':
        # silent=True yields None instead of an error when the body is not JSON
        return {'message': 'POST request received', 'data': request.get_json(silent=True)}
    elif request.method == 'PUT':
        return {'message': 'PUT request received'}
    elif request.method == 'DELETE':
        return {'message': 'DELETE request received'}
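URL converter types:
Type | Description | Example |
---|---|---|
string | Default; accepts any text that contains no slash | /user/<username> |
int | Accepts non-negative integers | /post/<int:post_id> |
float | Accepts non-negative floating-point values | /price/<float:price> |
path | Like string, but also accepts slashes | /file/<path:filename> |
uuid | Accepts UUID strings | /user/<uuid:user_id> |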
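The converters can be checked without starting a server by using Flask's built-in test client. The following is a minimal sketch, not part of the original application; the `/file/<path:filename>` route and the `show_file` view are illustrative names.

```python
from flask import Flask

app = Flask(__name__)

@app.route('/post/<int:post_id>')
def show_post(post_id):
    # The int converter has already turned the URL segment into an int
    return f'{type(post_id).__name__}:{post_id}'

@app.route('/price/<float:price>')
def show_price(price):
    return f'{type(price).__name__}:{price:.2f}'

@app.route('/file/<path:filename>')
def show_file(filename):
    # The path converter keeps embedded slashes
    return filename

client = app.test_client()

int_body = client.get('/post/42').get_data(as_text=True)
bad_int_status = client.get('/post/abc').status_code   # no rule matches
float_body = client.get('/price/9.5').get_data(as_text=True)
path_body = client.get('/file/docs/readme.txt').get_data(as_text=True)

print(int_body, bad_int_status, float_body, path_body)
```

A segment that does not satisfy the converter never reaches the view; the request simply fails to match any rule and Flask answers with 404.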
Advanced routing techniques:
# Optional parameter: both rules map to the same view
@app.route('/users/')
@app.route('/users/<int:page>')
def show_users(page=1):
    return f'Showing users page {page}'

# URL building
@app.route('/build-url')
def build_url():
    user_url = url_for('show_user', username='john')
    post_url = url_for('show_post', post_id=123)
    return f'User URL: {user_url}<br>Post URL: {post_url}'

# Redirect
@app.route('/old-page')
def old_page():
    return redirect(url_for('index'))
Custom decorators:
from functools import wraps

def require_api_key(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Demo only: a real key belongs in configuration, not in source code
        if request.args.get('api_key') != 'secret':
            return {'error': 'Invalid API key'}, 401
        return f(*args, **kwargs)
    return decorated_function

@app.route('/protected')
@require_api_key
def protected():
    return {'message': 'Protected data'}
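The `@wraps(f)` line matters because Flask derives the default endpoint name from the view function's `__name__`. Without it, every decorated view would be named after the inner wrapper, and registering a second such view would clash with the first. A small standard-library sketch (the decorator and function names here are illustrative only) shows the difference:

```python
from functools import wraps

def plain_decorator(f):
    # No @wraps: the returned function keeps its own name, "wrapper"
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

def wrapping_decorator(f):
    @wraps(f)  # copies __name__, __doc__ and similar metadata from f
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

@plain_decorator
def protected_a():
    return 'a'

@wrapping_decorator
def protected_b():
    return 'b'

print(protected_a.__name__)
print(protected_b.__name__)
```

The first name printed is `wrapper`, the second is `protected_b`; only the second gives Flask a unique endpoint name.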
Request hooks:
@app.before_request
def before_request():
    print(f"Request started: {request.method} {request.url}")

@app.after_request
def after_request(response):
    print("Request completed")
    return response
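Flask uses Jinja2 as its template engine, which provides template inheritance, macros, filters, and more.
Basic template syntax:
Syntax | Purpose | Example |
---|---|---|
{{ }} | Output a variable | {{ user.name }} |
{% %} | Control statements | {% if user %}...{% endif %} |
{# #} | Comments | {# this is a comment #} |
{% raw %} | Raw content, left unparsed | {% raw %}{{ not parsed }}{% endraw %} |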
Flask view function example:
# app.py
from flask import Flask, render_template
from datetime import datetime

app = Flask(__name__)

# Custom template filters
@app.template_filter('datetime')
def datetime_filter(value, format='%Y-%m-%d %H:%M'):
    return value.strftime(format)

@app.template_filter('currency')
def currency_filter(value):
    return f"${value:,.2f}"

# Custom global template function
@app.template_global()
def get_current_year():
    return datetime.now().year

@app.route('/profile/<name>')
def profile(name):
    user_data = {
        'name': name,
        'age': 25,
        'skills': ['Python', 'Flask', 'JavaScript', 'Docker'],
        'join_date': datetime(2020, 5, 15),
        'salary': 75000,
        'is_active': True
    }
    return render_template('profile.html', user=user_data)
Template inheritance structure (base.html):
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>{% block title %}Flask App{% endblock %}</title>
    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
</head>
<body>
    <nav class="navbar navbar-expand-lg navbar-dark bg-dark">
        <div class="container">
            <a class="navbar-brand" href="{{ url_for('index') }}">Flask App</a>
        </div>
    </nav>
    <div class="container mt-4">
        {% block content %}{% endblock %}
    </div>
    <footer class="bg-dark text-light text-center py-3 mt-5">
        <p>© {{ get_current_year() }} Flask App. All rights reserved.</p>
    </footer>
</body>
</html>
Child template example (profile.html):
{% extends "base.html" %}
{% block title %}{{ user.name }}'s Profile{% endblock %}
{% block content %}
<div class="row">
    <div class="col-md-4">
        <div class="card">
            <div class="card-body text-center">
                <h4>{{ user.name }}</h4>
                <p class="text-muted">
                    {% if user.is_active %}
                    <span class="badge bg-success">Active</span>
                    {% else %}
                    <span class="badge bg-secondary">Inactive</span>
                    {% endif %}
                </p>
                <p><strong>Member since:</strong> {{ user.join_date|datetime('%B %Y') }}</p>
                <p><strong>Salary:</strong> {{ user.salary|currency }}</p>
            </div>
        </div>
    </div>
    <div class="col-md-8">
        <div class="card">
            <div class="card-header">
                <h5>Skills ({{ user.skills|length }})</h5>
            </div>
            <div class="card-body">
                {% for skill in user.skills %}
                <span class="badge bg-primary me-2 mb-2">{{ skill }}</span>
                {% endfor %}
            </div>
        </div>
    </div>
</div>
{% endblock %}
Macros: reusable template fragments (macros.html):
{% macro render_field(field, label_class="", input_class="") %}
<div class="mb-3">
    {{ field.label(class=label_class) }}
    {{ field(class=input_class) }}
    {% if field.errors %}
        {% for error in field.errors %}
        <div class="text-danger small">{{ error }}</div>
        {% endfor %}
    {% endif %}
</div>
{% endmacro %}

{% macro render_alert(message, category='info') %}
<div class="alert alert-{{ category }} alert-dismissible fade show">
    {{ message }}
    <button type="button" class="btn-close" data-bs-dismiss="alert"></button>
</div>
{% endmacro %}
Using the macros:
{% from "macros.html" import render_field, render_alert %}
{{ render_alert('Operation succeeded!', 'success') }}
{{ render_field(form.username, input_class="form-control") }}
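Flask-WTF is Flask's form-handling extension. It integrates the WTForms library and provides form validation and CSRF protection.
Install Flask-WTF:
pip install Flask-WTF
# The Email() validator used below additionally needs: pip install email-validator
Basic form definitions: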
# forms.py
import re

from flask_wtf import FlaskForm
from flask_wtf.file import FileField, FileAllowed, FileRequired
from wtforms import StringField, TextAreaField, PasswordField, SelectField, BooleanField, IntegerField, DateField, FloatField
from wtforms.validators import DataRequired, Email, Length, EqualTo, NumberRange, Optional, ValidationError
from wtforms.widgets import TextArea, PasswordInput

class ContactForm(FlaskForm):
    name = StringField('Full Name', validators=[
        DataRequired(message='Name is required'),
        Length(min=2, max=50, message='Name must be between 2 and 50 characters')
    ])
    email = StringField('Email Address', validators=[
        DataRequired(message='Email is required'),
        Email(message='Invalid email address')
    ])
    phone = StringField('Phone Number', validators=[
        Optional(),
        Length(min=10, max=15, message='Phone number must be between 10 and 15 digits')
    ])
    subject = SelectField('Subject', choices=[
        ('general', 'General Inquiry'),
        ('support', 'Technical Support'),
        ('billing', 'Billing Question'),
        ('feedback', 'Feedback')
    ], validators=[DataRequired()])
    message = TextAreaField('Message', validators=[
        DataRequired(message='Message is required'),
        Length(min=10, max=1000, message='Message must be between 10 and 1000 characters')
    ], widget=TextArea())
    newsletter = BooleanField('Subscribe to newsletter')

    # WTForms calls validate_<fieldname> automatically as an inline validator
    def validate_phone(self, field):
        if field.data:
            # Check the phone number format (North American style)
            pattern = r'^\+?1?-?\.?\s?\(?(\d{3})\)?[\s.-]?(\d{3})[\s.-]?(\d{4})$'
            if not re.match(pattern, field.data):
                raise ValidationError('Invalid phone number format')
class UserRegistrationForm(FlaskForm):
    username = StringField('Username', validators=[
        DataRequired(),
        Length(min=4, max=20, message='Username must be between 4 and 20 characters')
    ])
    email = StringField('Email', validators=[
        DataRequired(),
        Email()
    ])
    password = PasswordField('Password', validators=[
        DataRequired(),
        Length(min=8, message='Password must be at least 8 characters long')
    ], widget=PasswordInput(hide_value=False))
    confirm_password = PasswordField('Confirm Password', validators=[
        DataRequired(),
        EqualTo('password', message='Passwords must match')
    ])
    age = IntegerField('Age', validators=[
        Optional(),
        NumberRange(min=13, max=120, message='Age must be between 13 and 120')
    ])
    birth_date = DateField('Birth Date', validators=[Optional()])
    avatar = FileField('Profile Picture', validators=[
        FileAllowed(['jpg', 'png', 'gif'], 'Images only!'),
        Optional()
    ])

    def validate_username(self, field):
        # Custom username validation
        if not re.match(r'^[a-zA-Z0-9_]+$', field.data):
            raise ValidationError('Username can only contain letters, numbers, and underscores')
        # Reserved names; a real application would also query the database here
        forbidden_usernames = ['admin', 'root', 'administrator']
        if field.data.lower() in forbidden_usernames:
            raise ValidationError('This username is not allowed')

    def validate_password(self, field):
        # Password strength validation
        password = field.data
        if not re.search(r'[A-Z]', password):
            raise ValidationError('Password must contain at least one uppercase letter')
        if not re.search(r'[a-z]', password):
            raise ValidationError('Password must contain at least one lowercase letter')
        if not re.search(r'\d', password):
            raise ValidationError('Password must contain at least one digit')
        if not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
            raise ValidationError('Password must contain at least one special character')
class ProductForm(FlaskForm):
    name = StringField('Product Name', validators=[
        DataRequired(),
        Length(min=2, max=100)
    ])
    description = TextAreaField('Description', validators=[
        Optional(),
        Length(max=500)
    ])
    price = FloatField('Price', validators=[
        DataRequired(),
        NumberRange(min=0.01, message='Price must be greater than 0')
    ])
    category = SelectField('Category', coerce=int, validators=[DataRequired()])
    in_stock = BooleanField('In Stock')
    tags = StringField('Tags (comma separated)', validators=[Optional()])

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Load category choices at runtime (hard-coded here; normally from the database)
        self.category.choices = [(1, 'Electronics'), (2, 'Clothing'), (3, 'Books')]
# Dynamic form generation
def create_dynamic_form(fields_config):
    """Build a form from a configuration dict and return an instance of it."""
    class DynamicFormClass(FlaskForm):
        pass

    for field_name, field_config in fields_config.items():
        field_type = field_config['type']
        validators = field_config.get('validators', [])
        if field_type == 'string':
            field = StringField(field_config['label'], validators=validators)
        elif field_type == 'email':
            field = StringField(field_config['label'], validators=validators + [Email()])
        elif field_type == 'textarea':
            field = TextAreaField(field_config['label'], validators=validators)
        elif field_type == 'select':
            field = SelectField(field_config['label'],
                                choices=field_config['choices'],
                                validators=validators)
        elif field_type == 'boolean':
            field = BooleanField(field_config['label'])
        else:
            # Fail loudly instead of reusing the previous iteration's field
            raise ValueError(f'Unsupported field type: {field_type}')
        setattr(DynamicFormClass, field_name, field)
    return DynamicFormClass()
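The password rules in `validate_password` can be tried outside Flask-WTF, which makes it easy to see which inputs pass. The sketch below restates the same four regular expressions as a plain function; `PASSWORD_RULES` and `password_problems` are hypothetical names introduced for illustration, and unlike the form validator (which stops at the first failure) this helper collects every failed rule.

```python
import re

# Each rule pairs a regular expression with the message shown when it fails
PASSWORD_RULES = [
    (r'[A-Z]', 'Password must contain at least one uppercase letter'),
    (r'[a-z]', 'Password must contain at least one lowercase letter'),
    (r'\d', 'Password must contain at least one digit'),
    (r'[!@#$%^&*(),.?":{}|<>]', 'Password must contain at least one special character'),
]

def password_problems(password):
    """Return the message of every rule that the password fails."""
    return [message for pattern, message in PASSWORD_RULES
            if not re.search(pattern, password)]

print(password_problems('Str0ng!pass'))
print(password_problems('weakpass'))
```

The first call returns an empty list; the second reports the missing uppercase letter, digit, and special character.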
Flask application setup:
# app.py - form-handling views
import os

from flask import Flask, render_template, request, redirect, url_for, flash, jsonify
from flask_wtf.csrf import CSRFProtect
from werkzeug.utils import secure_filename

from forms import ContactForm, UserRegistrationForm, ProductForm

app = Flask(__name__)
# Placeholder value: load the real secret from the environment in production
app.secret_key = 'your-secret-key-here'
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
csrf = CSRFProtect(app)

@app.route('/contact', methods=['GET', 'POST'])
def contact():
    form = ContactForm()
    if form.validate_on_submit():
        # Collect the submitted data
        contact_data = {
            'name': form.name.data,
            'email': form.email.data,
            'phone': form.phone.data,
            'subject': form.subject.data,
            'message': form.message.data,
            'newsletter': form.newsletter.data
        }
        # Save to the database or send an email here
        print(f"Contact form submitted: {contact_data}")
        flash('Thank you for your message! We will get back to you soon.', 'success')
        return redirect(url_for('contact'))
    return render_template('contact_form.html', form=form)
Form template example (contact_form.html):
{% extends "base.html" %}
{% block title %}Contact Us{% endblock %}
{% block content %}
<div class="row justify-content-center">
    <div class="col-md-8">
        <div class="card">
            <div class="card-header">
                <h4 class="mb-0">Contact Us</h4>
            </div>
            <div class="card-body">
                <form method="POST">
                    {{ form.hidden_tag() }}
                    <div class="row">
                        <div class="col-md-6">
                            <div class="mb-3">
                                {{ form.name.label(class="form-label") }}
                                {{ form.name(class="form-control") }}
                                {% if form.name.errors %}
                                    {% for error in form.name.errors %}
                                    <div class="text-danger small">{{ error }}</div>
                                    {% endfor %}
                                {% endif %}
                            </div>
                        </div>
                        <div class="col-md-6">
                            <div class="mb-3">
                                {{ form.email.label(class="form-label") }}
                                {{ form.email(class="form-control") }}
                                {% if form.email.errors %}
                                    {% for error in form.email.errors %}
                                    <div class="text-danger small">{{ error }}</div>
                                    {% endfor %}
                                {% endif %}
                            </div>
                        </div>
                    </div>
                    <div class="row">
                        <div class="col-md-6">
                            <div class="mb-3">
                                {{ form.phone.label(class="form-label") }}
                                {{ form.phone(class="form-control") }}
                                {% if form.phone.errors %}
                                    {% for error in form.phone.errors %}
                                    <div class="text-danger small">{{ error }}</div>
                                    {% endfor %}
                                {% endif %}
                            </div>
                        </div>
                        <div class="col-md-6">
                            <div class="mb-3">
                                {{ form.subject.label(class="form-label") }}
                                {{ form.subject(class="form-select") }}
                                {% if form.subject.errors %}
                                    {% for error in form.subject.errors %}
                                    <div class="text-danger small">{{ error }}</div>
                                    {% endfor %}
                                {% endif %}
                            </div>
                        </div>
                    </div>
                    <div class="mb-3">
                        {{ form.message.label(class="form-label") }}
                        {{ form.message(class="form-control", rows="5") }}
                        {% if form.message.errors %}
                            {% for error in form.message.errors %}
                            <div class="text-danger small">{{ error }}</div>
                            {% endfor %}
                        {% endif %}
                    </div>
                    <div class="form-check mb-3">
                        {{ form.newsletter(class="form-check-input") }}
                        {{ form.newsletter.label(class="form-check-label") }}
                    </div>
                    <div class="d-grid">
                        <button type="submit" class="btn btn-primary">Send Message</button>
                    </div>
                </form>
            </div>
        </div>
    </div>
</div>
{% endblock %}
# app.py (continued)
@app.route('/register', methods=['GET', 'POST'])
def register():
    form = UserRegistrationForm()
    if form.validate_on_submit():
        # Handle the file upload
        avatar_filename = None
        if form.avatar.data:
            avatar_file = form.avatar.data
            avatar_filename = secure_filename(avatar_file.filename)
            # Make sure the upload directory exists before saving
            os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
            avatar_path = os.path.join(app.config['UPLOAD_FOLDER'], avatar_filename)
            avatar_file.save(avatar_path)
        user_data = {
            'username': form.username.data,
            'email': form.email.data,
            'age': form.age.data,
            'birth_date': form.birth_date.data,
            'avatar': avatar_filename
        }
        # Save the user to the database here
        print(f"User registered: {user_data}")
        flash('Registration successful! Please log in.', 'success')
        # Assumes a 'login' view is defined elsewhere
        return redirect(url_for('login'))
    return render_template('register_form.html', form=form)

@app.route('/product/new', methods=['GET', 'POST'])
def new_product():
    form = ProductForm()
    if form.validate_on_submit():
        # Split the comma-separated tags; the field is optional, so guard against None
        tags = [tag.strip() for tag in (form.tags.data or '').split(',') if tag.strip()]
        product_data = {
            'name': form.name.data,
            'description': form.description.data,
            'price': form.price.data,
            'category': form.category.data,
            'in_stock': form.in_stock.data,
            'tags': tags
        }
        print(f"Product created: {product_data}")
        flash('Product created successfully!', 'success')
        # Assumes a 'product_list' view is defined elsewhere
        return redirect(url_for('product_list'))
    return render_template('product_form.html', form=form)
# AJAX form validation
@app.route('/validate-username', methods=['POST'])
def validate_username():
    # Tolerate a missing or non-JSON body instead of raising
    data = request.get_json(silent=True) or {}
    username = data.get('username') or ''
    # Simulated database lookup
    existing_usernames = ['admin', 'user1', 'testuser']
    if username in existing_usernames:
        return jsonify({'valid': False, 'message': 'Username already exists'})
    if len(username) < 4:
        return jsonify({'valid': False, 'message': 'Username too short'})
    return jsonify({'valid': True, 'message': 'Username available'})
# Multi-step form; intermediate data is kept in the session
from flask import session

@app.route('/wizard', methods=['GET', 'POST'])
def form_wizard():
    step = request.args.get('step', 1, type=int)
    if step == 1:
        # Step 1: basic information
        if request.method == 'POST':
            session['step1_data'] = request.form.to_dict()
            return redirect(url_for('form_wizard', step=2))
        return render_template('wizard_step1.html')
    elif step == 2:
        # Step 2: details
        if request.method == 'POST':
            session['step2_data'] = request.form.to_dict()
            return redirect(url_for('form_wizard', step=3))
        return render_template('wizard_step2.html')
    elif step == 3:
        # Step 3: confirmation
        if request.method == 'POST':
            # Merge the data from every step; later steps override earlier keys
            final_data = {
                **session.get('step1_data', {}),
                **session.get('step2_data', {}),
                **request.form.to_dict()
            }
            # Process the final data
            print(f"Wizard completed: {final_data}")
            # Clear the session
            session.pop('step1_data', None)
            session.pop('step2_data', None)
            flash('Form submitted successfully!', 'success')
            return redirect(url_for('index'))
        return render_template('wizard_step3.html',
                               step1_data=session.get('step1_data', {}),
                               step2_data=session.get('step2_data', {}))
    # Any other step number restarts the wizard
    return redirect(url_for('form_wizard', step=1))
Form template with styling:
{% extends "base.html" %}
{% from "macros.html" import render_field %}
{# The extra_css and extra_js blocks are only rendered if base.html defines them #}
{% block title %}Contact Us{% endblock %}
{% block extra_css %}
<style>
.form-container {
    max-width: 600px;
    margin: 0 auto;
    padding: 2rem;
    background: #f8f9fa;
    border-radius: 10px;
    box-shadow: 0 0 20px rgba(0,0,0,0.1);
}
.form-group {
    margin-bottom: 1.5rem;
}
.form-control:focus {
    border-color: #007bff;
    box-shadow: 0 0 0 0.2rem rgba(0,123,255,.25);
}
.btn-submit {
    background: linear-gradient(45deg, #007bff, #0056b3);
    border: none;
    padding: 12px 30px;
    border-radius: 25px;
    color: white;
    font-weight: bold;
    transition: all 0.3s ease;
}
.btn-submit:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,123,255,0.4);
}
</style>
{% endblock %}
{% block content %}
<div class="form-container">
    <h2 class="text-center mb-4">Get In Touch</h2>
    <form method="POST" novalidate>
        {{ form.hidden_tag() }}
        <div class="row">
            <div class="col-md-6">
                {{ render_field(form.name, input_class="form-control") }}
            </div>
            <div class="col-md-6">
                {{ render_field(form.email, input_class="form-control") }}
            </div>
        </div>
        <div class="row">
            <div class="col-md-6">
                {{ render_field(form.phone, input_class="form-control") }}
            </div>
            <div class="col-md-6">
                {{ render_field(form.subject, input_class="form-select") }}
            </div>
        </div>
        {# render_field accepts no style argument, so the message field is rendered directly #}
        <div class="mb-3">
            {{ form.message.label }}
            {{ form.message(class="form-control", style="height: 120px;") }}
            {% for error in form.message.errors %}
            <div class="text-danger small">{{ error }}</div>
            {% endfor %}
        </div>
        <div class="form-check mb-3">
            {{ form.newsletter(class="form-check-input") }}
            {{ form.newsletter.label(class="form-check-label") }}
        </div>
        <div class="text-center">
            <button type="submit" class="btn btn-submit">
                <i class="fas fa-paper-plane me-2"></i>Send Message
            </button>
        </div>
    </form>
</div>
{% endblock %}
{% block extra_js %}
<script>
// Live form validation
document.addEventListener('DOMContentLoaded', function() {
    const form = document.querySelector('form');
    const inputs = form.querySelectorAll('input, textarea, select');

    inputs.forEach(input => {
        input.addEventListener('blur', function() {
            validateField(this);
        });
        input.addEventListener('input', function() {
            clearErrors(this);
        });
    });

    function validateField(field) {
        const value = field.value.trim();
        const fieldName = field.name;
        // Clear earlier errors
        clearErrors(field);
        // Required check
        if (field.hasAttribute('required') && !value) {
            showError(field, 'This field is required');
            return false;
        }
        // Email check
        if (fieldName === 'email' && value) {
            const emailRegex = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
            if (!emailRegex.test(value)) {
                showError(field, 'Please enter a valid email address');
                return false;
            }
        }
        // Phone check (same pattern as the server-side validator)
        if (fieldName === 'phone' && value) {
            const phoneRegex = /^\+?1?-?\.?\s?\(?(\d{3})\)?[\s.-]?(\d{3})[\s.-]?(\d{4})$/;
            if (!phoneRegex.test(value)) {
                showError(field, 'Please enter a valid phone number');
                return false;
            }
        }
        return true;
    }

    function showError(field, message) {
        field.classList.add('is-invalid');
        const errorDiv = document.createElement('div');
        errorDiv.className = 'invalid-feedback';
        errorDiv.textContent = message;
        field.parentNode.appendChild(errorDiv);
    }

    function clearErrors(field) {
        field.classList.remove('is-invalid');
        const errorDiv = field.parentNode.querySelector('.invalid-feedback');
        if (errorDiv) {
            errorDiv.remove();
        }
    }

    // Validate everything on submit
    form.addEventListener('submit', function(e) {
        let isValid = true;
        inputs.forEach(input => {
            if (!validateField(input)) {
                isValid = false;
            }
        });
        if (!isValid) {
            e.preventDefault();
            const firstError = form.querySelector('.is-invalid');
            if (firstError) {
                firstError.focus();
                firstError.scrollIntoView({ behavior: 'smooth', block: 'center' });
            }
        }
    });
});
</script>
{% endblock %}
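Flask-SQLAlchemy is Flask's database extension; it simplifies using SQLAlchemy inside a Flask application.
Install the dependencies:
pip install Flask-SQLAlchemy
pip install PyMySQL # MySQL driver
# or
pip install psycopg2 # PostgreSQL driver
Database configuration: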
# config.py
import os

class Config:
    SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key'
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    SQLALCHEMY_RECORD_QUERIES = True

class DevelopmentConfig(Config):
    DEBUG = True
    SQLALCHEMY_DATABASE_URI = os.environ.get('DEV_DATABASE_URL') or \
        'sqlite:///app.db'

class ProductionConfig(Config):
    DEBUG = False
    SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
        'postgresql://user:password@localhost/app_db'

config = {
    'development': DevelopmentConfig,
    'production': ProductionConfig,
    'default': DevelopmentConfig
}
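One way to use the `config` dictionary is to pick the class by name from an environment variable at startup. The sketch below is only an illustration of that lookup: the variable name `APP_ENV` and the helper `pick_config` are assumptions, not part of the original project, and the config classes are reduced to a single attribute.

```python
import os

class Config:
    DEBUG = False

class DevelopmentConfig(Config):
    DEBUG = True

class ProductionConfig(Config):
    DEBUG = False

config = {
    'development': DevelopmentConfig,
    'production': ProductionConfig,
    'default': DevelopmentConfig,
}

def pick_config(environ=os.environ):
    """Return the config class named by APP_ENV (hypothetical variable name)."""
    name = environ.get('APP_ENV', 'default')
    # Unknown names fall back to the default instead of raising KeyError
    return config.get(name, config['default'])

print(pick_config({'APP_ENV': 'production'}).__name__)
print(pick_config({}).__name__)
```

The selected class can then be handed to `app.config.from_object(...)`; falling back to the default keeps a mistyped environment name from crashing the application at import time.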
Basic model example:
# models.py
from flask_sqlalchemy import SQLAlchemy
from flask_login import UserMixin  # requires: pip install Flask-Login
from werkzeug.security import generate_password_hash, check_password_hash
from datetime import datetime

db = SQLAlchemy()

class User(UserMixin, db.Model):
    __tablename__ = 'users'
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    # Personal information
    first_name = db.Column(db.String(50))
    last_name = db.Column(db.String(50))
    bio = db.Column(db.Text)
    # Status fields
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    # Relationships
    posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')

    def set_password(self, password):
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        return check_password_hash(self.password_hash, password)

    @property
    def full_name(self):
        if self.first_name and self.last_name:
            return f"{self.first_name} {self.last_name}"
        return self.username

    def to_dict(self):
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'full_name': self.full_name,
            'is_active': self.is_active,
            'created_at': self.created_at.isoformat()
        }

    def __repr__(self):
        return f'<User {self.username}>'
class Post(db.Model):
    __tablename__ = 'posts'
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    content = db.Column(db.Text, nullable=False)
    # Status fields
    is_published = db.Column(db.Boolean, default=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    # Foreign key
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    def to_dict(self):
        return {
            'id': self.id,
            'title': self.title,
            'content': self.content,
            'is_published': self.is_published,
            'created_at': self.created_at.isoformat(),
            'author': self.author.username
        }

    def __repr__(self):
        return f'<Post {self.title}>'
# app.py
from flask import Flask, render_template, request, redirect, url_for, flash
from models import db, User, Post
from config import config

def create_app(config_name='development'):
    app = Flask(__name__)
    app.config.from_object(config[config_name])
    # Initialize extensions
    db.init_app(app)
    # Create the database tables
    with app.app_context():
        db.create_all()
    return app

app = create_app()

@app.route('/')
def index():
    posts = Post.query.filter_by(is_published=True).order_by(Post.created_at.desc()).limit(10).all()
    return render_template('index.html', posts=posts)

@app.route('/users')
def users():
    page = request.args.get('page', 1, type=int)
    users = User.query.paginate(
        page=page, per_page=10, error_out=False
    )
    return render_template('users.html', users=users)

@app.route('/user/<int:user_id>')
def user_profile(user_id):
    user = User.query.get_or_404(user_id)
    posts = user.posts.filter_by(is_published=True).order_by(Post.created_at.desc()).all()
    return render_template('user_profile.html', user=user, posts=posts)
# models.py - a fuller data model
from flask_sqlalchemy import SQLAlchemy
from flask_login import UserMixin
from werkzeug.security import generate_password_hash, check_password_hash
from datetime import datetime, timedelta
from sqlalchemy import event, Index
from sqlalchemy.ext.hybrid import hybrid_property
import uuid

db = SQLAlchemy()

# Association tables (many-to-many relationships)
user_roles = db.Table('user_roles',
    db.Column('user_id', db.Integer, db.ForeignKey('user.id'), primary_key=True),
    db.Column('role_id', db.Integer, db.ForeignKey('role.id'), primary_key=True)
)
post_tags = db.Table('post_tags',
    db.Column('post_id', db.Integer, db.ForeignKey('post.id'), primary_key=True),
    db.Column('tag_id', db.Integer, db.ForeignKey('tag.id'), primary_key=True)
)

class TimestampMixin:
    """Mixin that adds created/updated timestamps."""
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
class User(UserMixin, TimestampMixin, db.Model):
    __tablename__ = 'user'
    id = db.Column(db.Integer, primary_key=True)
    uuid = db.Column(db.String(36), unique=True, nullable=False, default=lambda: str(uuid.uuid4()))
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    # Personal information
    first_name = db.Column(db.String(50))
    last_name = db.Column(db.String(50))
    avatar_url = db.Column(db.String(255))
    bio = db.Column(db.Text)
    birth_date = db.Column(db.Date)
    # Status fields
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    is_verified = db.Column(db.Boolean, default=False, nullable=False)
    last_login = db.Column(db.DateTime)
    login_count = db.Column(db.Integer, default=0)
    # Relationships
    posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')
    comments = db.relationship('Comment', backref='author', lazy='dynamic', cascade='all, delete-orphan')
    roles = db.relationship('Role', secondary=user_roles, backref=db.backref('users', lazy='dynamic'))
    # Indexes
    __table_args__ = (
        Index('idx_user_email_active', 'email', 'is_active'),
        Index('idx_user_username_active', 'username', 'is_active'),
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not self.uuid:
            self.uuid = str(uuid.uuid4())

    @hybrid_property
    def full_name(self):
        if self.first_name and self.last_name:
            return f"{self.first_name} {self.last_name}"
        return self.username

    @hybrid_property
    def age(self):
        if self.birth_date:
            today = datetime.now().date()
            # Subtract one if this year's birthday has not happened yet
            return today.year - self.birth_date.year - ((today.month, today.day) < (self.birth_date.month, self.birth_date.day))
        return None

    def set_password(self, password):
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        return check_password_hash(self.password_hash, password)

    def has_role(self, role_name):
        return any(role.name == role_name for role in self.roles)

    def add_role(self, role):
        if not self.has_role(role.name):
            self.roles.append(role)

    def remove_role(self, role):
        if self.has_role(role.name):
            self.roles.remove(role)

    def get_posts_count(self):
        return self.posts.count()

    def get_recent_posts(self, limit=5):
        return self.posts.order_by(Post.created_at.desc()).limit(limit)

    def to_dict(self, include_email=False):
        data = {
            'id': self.id,
            'uuid': self.uuid,
            'username': self.username,
            'full_name': self.full_name,
            'avatar_url': self.avatar_url,
            'bio': self.bio,
            'is_active': self.is_active,
            'created_at': self.created_at.isoformat(),
            'posts_count': self.get_posts_count()
        }
        if include_email:
            data['email'] = self.email
        return data

    def __repr__(self):
        return f'<User {self.username}>'
class Role(db.Model):
    __tablename__ = 'role'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.String(255))
    permissions = db.Column(db.JSON)  # stores the list of permissions

    def __repr__(self):
        return f'<Role {self.name}>'

class Category(TimestampMixin, db.Model):
    __tablename__ = 'category'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    slug = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)
    color = db.Column(db.String(7), default='#007bff')  # hexadecimal color
    is_active = db.Column(db.Boolean, default=True)
    # Self-referential relationship (parent and child categories)
    parent_id = db.Column(db.Integer, db.ForeignKey('category.id'))
    children = db.relationship('Category', backref=db.backref('parent', remote_side=[id]))
    # Relationships
    posts = db.relationship('Post', backref='category', lazy='dynamic')

    def get_posts_count(self):
        return self.posts.filter_by(is_published=True).count()

    def __repr__(self):
        return f'<Category {self.name}>'

class Tag(db.Model):
    __tablename__ = 'tag'
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    slug = db.Column(db.String(50), unique=True, nullable=False)
    color = db.Column(db.String(7), default='#6c757d')

    def get_posts_count(self):
        # 'posts' is a lazy='dynamic' backref (a query object), so use count(), not len()
        return self.posts.count()

    def __repr__(self):
        return f'<Tag {self.name}>'
class Post(TimestampMixin, db.Model):
__tablename__ = 'post'
id = db.Column(db.Integer, primary_key=True)
uuid = db.Column(db.String(36), unique=True, nullable=False, default=lambda: str(uuid.uuid4()))
title = db.Column(db.String(200), nullable=False)
slug = db.Column(db.String(200), unique=True, nullable=False)
content = db.Column(db.Text, nullable=False)
excerpt = db.Column(db.Text)
featured_image = db.Column(db.String(255))
# 状态字段
is_published = db.Column(db.Boolean, default=False)
is_featured = db.Column(db.Boolean, default=False)
published_at = db.Column(db.DateTime)
# 统计字段
view_count = db.Column(db.Integer, default=0)
like_count = db.Column(db.Integer, default=0)
comment_count = db.Column(db.Integer, default=0)
# 外键
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
category_id = db.Column(db.Integer, db.ForeignKey('category.id'))
# 关系
comments = db.relationship('Comment', backref='post', lazy='dynamic', cascade='all, delete-orphan')
tags = db.relationship('Tag', secondary=post_tags, backref=db.backref('posts', lazy='dynamic'))
# 索引
__table_args__ = (
Index('idx_post_published', 'is_published', 'published_at'),
Index('idx_post_author', 'user_id', 'is_published'),
Index('idx_post_category', 'category_id', 'is_published'),
)
def __init__(self, **kwargs):
super(Post, self).__init__(**kwargs)
if not self.uuid:
self.uuid = str(uuid.uuid4())
if not self.excerpt and self.content:
self.excerpt = self.content[:200] + '...' if len(self.content) > 200 else self.content
@hybrid_property
def reading_time(self):
"""估算阅读时间(分钟)"""
words_per_minute = 200
word_count = len(self.content.split())
return max(1, round(word_count / words_per_minute))
def publish(self):
self.is_published = True
self.published_at = datetime.utcnow()
def unpublish(self):
self.is_published = False
self.published_at = None
def increment_view_count(self):
self.view_count += 1
db.session.commit()
def add_tag(self, tag):
if tag not in self.tags:
self.tags.append(tag)
def remove_tag(self, tag):
if tag in self.tags:
self.tags.remove(tag)
def get_related_posts(self, limit=5):
"""获取相关文章"""
tag_ids = [tag.id for tag in self.tags]
if not tag_ids:
return Post.query.filter(
Post.id != self.id,
Post.is_published == True,
Post.category_id == self.category_id
).limit(limit).all()
return Post.query.join(post_tags).filter(
post_tags.c.tag_id.in_(tag_ids),
Post.id != self.id,
Post.is_published == True
).limit(limit).all()
def to_dict(self):
return {
'id': self.id,
'uuid': self.uuid,
'title': self.title,
'slug': self.slug,
'excerpt': self.excerpt,
'content': self.content,
'featured_image': self.featured_image,
'is_published': self.is_published,
'published_at': self.published_at.isoformat() if self.published_at else None,
'view_count': self.view_count,
'like_count': self.like_count,
'comment_count': self.comment_count,
'reading_time': self.reading_time,
'author': self.author.to_dict(),
'category': self.category.name if self.category else None,
'tags': [tag.name for tag in self.tags],
'created_at': self.created_at.isoformat(),
'updated_at': self.updated_at.isoformat()
}
def __repr__(self):
return f'<Post {self.title}>'
class Comment(TimestampMixin, db.Model):
__tablename__ = 'comment'
id = db.Column(db.Integer, primary_key=True)
content = db.Column(db.Text, nullable=False)
is_approved = db.Column(db.Boolean, default=False)
# Foreign keys
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
post_id = db.Column(db.Integer, db.ForeignKey('post.id'), nullable=False)
# Self-referential relationship (replies to a comment)
parent_id = db.Column(db.Integer, db.ForeignKey('comment.id'))
replies = db.relationship('Comment', backref=db.backref('parent', remote_side=[id]))
def approve(self):
self.is_approved = True
def get_replies_count(self):
return len(self.replies)
def __repr__(self):
return f'<Comment {self.id}>'
# Database event listeners
@event.listens_for(Post, 'before_insert')
@event.listens_for(Post, 'before_update')
def generate_slug(mapper, connection, target):
"""Generate the slug automatically when it is missing."""
if target.title and not target.slug:
import re
slug = re.sub(r'[^\w\s-]', '', target.title.lower())
slug = re.sub(r'[-\s]+', '-', slug)
target.slug = slug
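@event.listens_for(Comment, 'after_insert')
def update_comment_count(mapper, connection, target):
    """Update the post's approved-comment count."""
    from sqlalchemy import func, select
    # Mapper events fire in the middle of a flush, so the session must not be
    # queried or committed here; use the connection passed to the listener.
    # (select(...).scalar_subquery() assumes SQLAlchemy 1.4 or newer.)
    approved_count = (
        select(func.count(Comment.id))
        .where(Comment.post_id == target.post_id, Comment.is_approved == True)
        .scalar_subquery()
    )
    connection.execute(
        Post.__table__.update()
        .where(Post.__table__.c.id == target.post_id)
        .values(comment_count=approved_count)
    )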
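The slug listener above derives the slug from the title but does nothing about two posts that share a title. The following is a minimal standalone sketch of the same regular-expression approach with a numeric suffix for collisions. The function names `slugify` and `unique_slug`, the `existing` set, and the `"post"` fallback are assumptions for illustration and do not appear in the original models.

```python
import re


def slugify(title: str) -> str:
    """Turn a title into a lowercase, hyphen-separated slug."""
    # Drop everything except word characters, whitespace and hyphens.
    cleaned = re.sub(r"[^\w\s-]", "", title.lower())
    # Collapse runs of whitespace/hyphens and trim hyphens at the ends.
    return re.sub(r"[-\s]+", "-", cleaned).strip("-")


def unique_slug(title: str, existing: set) -> str:
    """Return a slug that is not in `existing`, appending -2, -3, ... if needed."""
    base = slugify(title) or "post"  # hypothetical fallback for empty slugs
    candidate, suffix = base, 2
    while candidate in existing:
        candidate = f"{base}-{suffix}"
        suffix += 1
    return candidate


print(slugify("Hello, Flask World!"))
print(unique_slug("Hello World", {"hello-world", "hello-world-2"}))
```

In a real application the `existing` set would be replaced by a database lookup, and a unique constraint on the slug column would still be needed to cover concurrent inserts.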
Advanced Queries and Database Operations
# database_operations.py
from flask import Flask, request, jsonify
from sqlalchemy import func, and_, or_, not_, text, case
from sqlalchemy.orm import joinedload, selectinload, contains_eager
from models import db, User, Post, Comment, Category, Tag, Role  # Role is used by seed_db below
from datetime import datetime, timedelta
app = Flask(__name__)
class DatabaseService:
"""Database service class that wraps common database operations."""
@staticmethod
def get_user_stats():
"""Return user statistics."""
total_users = User.query.count()
active_users = User.query.filter_by(is_active=True).count()
verified_users = User.query.filter_by(is_verified=True).count()
# Users registered within the last 30 days
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
recent_users = User.query.filter(User.created_at >= thirty_days_ago).count()
return {
'total_users': total_users,
'active_users': active_users,
'verified_users': verified_users,
'recent_users': recent_users,
'inactive_users': total_users - active_users
}
@staticmethod
def get_popular_posts(limit=10):
"""Return the most popular published posts."""
return Post.query.filter_by(is_published=True)\
.order_by(Post.view_count.desc(), Post.like_count.desc())\
.limit(limit).all()
@staticmethod
def search_posts(query, category_id=None, tag_ids=None, page=1, per_page=10):
"""Search published posts."""
posts_query = Post.query.filter(
Post.is_published == True,
or_(
Post.title.contains(query),
Post.content.contains(query),
Post.excerpt.contains(query)
)
)
if category_id:
posts_query = posts_query.filter_by(category_id=category_id)
if tag_ids:
# any() emits an EXISTS subquery, so a post matching several tags is not duplicated.
posts_query = posts_query.filter(Post.tags.any(Tag.id.in_(tag_ids)))
return posts_query.order_by(Post.published_at.desc())\
.paginate(page=page, per_page=per_page, error_out=False)
@staticmethod
def get_user_activity(user_id, days=30):
"""Return activity statistics for a user."""
start_date = datetime.utcnow() - timedelta(days=days)
# Number of posts the user created in the period
posts_count = Post.query.filter(
Post.user_id == user_id,
Post.created_at >= start_date
).count()
# Number of comments the user wrote in the period
comments_count = Comment.query.filter(
Comment.user_id == user_id,
Comment.created_at >= start_date
).count()
# Total views across all of the user's posts (not limited to the period)
total_views = db.session.query(func.sum(Post.view_count))\
.filter(Post.user_id == user_id).scalar() or 0
return {
'posts_count': posts_count,
'comments_count': comments_count,
'total_views': total_views
}
@staticmethod
def get_category_stats():
"""Return per-category statistics."""
stats = db.session.query(
Category.id,
Category.name,
func.count(Post.id).label('posts_count'),
func.sum(Post.view_count).label('total_views')
).outerjoin(Post, and_(Post.category_id == Category.id, Post.is_published == True))\
.group_by(Category.id, Category.name)\
.order_by(func.count(Post.id).desc()).all()
return [
{
'id': stat.id,
'name': stat.name,
'posts_count': stat.posts_count or 0,
'total_views': stat.total_views or 0
}
for stat in stats
]
@staticmethod
def get_trending_tags(limit=20):
"""Return the most used tags."""
trending = db.session.query(
Tag.id,
Tag.name,
func.count(Post.id).label('posts_count')
).join(Tag.posts)\
.filter(Post.is_published == True)\
.group_by(Tag.id, Tag.name)\
.order_by(func.count(Post.id).desc())\
.limit(limit).all()
return [
{
'id': tag.id,
'name': tag.name,
'posts_count': tag.posts_count
}
for tag in trending
]
    @staticmethod
    def get_monthly_post_stats(year=None):
        """Return the number of published posts per month for the given year."""
        # sqlalchemy.extract renders EXTRACT(month FROM ...); func.extract
        # would emit a plain function call that many databases reject.
        from sqlalchemy import extract
        if not year:
            year = datetime.utcnow().year
        month = extract('month', Post.published_at)
        stats = db.session.query(
            month.label('month'),
            func.count(Post.id).label('posts_count')
        ).filter(
            Post.is_published == True,
            extract('year', Post.published_at) == year
        ).group_by(month)\
         .order_by(month).all()
        # Fill in every month so that months without posts report 0.
        monthly_stats = {i: 0 for i in range(1, 13)}
        for stat in stats:
            monthly_stats[int(stat.month)] = stat.posts_count
        return monthly_stats
# Complex query examples
@app.route('/api/advanced-search')
def advanced_search():
"""Advanced search API."""
# Read the query parameters
query = request.args.get('q', '')
category_id = request.args.get('category_id', type=int)
tag_names = request.args.getlist('tags')
author_id = request.args.get('author_id', type=int)
date_from = request.args.get('date_from')
date_to = request.args.get('date_to')
sort_by = request.args.get('sort_by', 'published_at')
order = request.args.get('order', 'desc')
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
# Build the query
posts_query = Post.query.filter_by(is_published=True)
# Text search
if query:
posts_query = posts_query.filter(
or_(
Post.title.ilike(f'%{query}%'),
Post.content.ilike(f'%{query}%'),
Post.excerpt.ilike(f'%{query}%')
)
)
# Filter by category
if category_id:
posts_query = posts_query.filter_by(category_id=category_id)
# Filter by tag
if tag_names:
# any() emits an EXISTS subquery, so a post matching several tags is not duplicated.
posts_query = posts_query.filter(Post.tags.any(Tag.name.in_(tag_names)))
# Filter by author
if author_id:
posts_query = posts_query.filter_by(user_id=author_id)
# Filter by date range (invalid dates are ignored)
if date_from:
try:
date_from = datetime.strptime(date_from, '%Y-%m-%d')
posts_query = posts_query.filter(Post.published_at >= date_from)
except ValueError:
pass
if date_to:
try:
# Treat date_to as inclusive: match everything before the following midnight.
date_to = datetime.strptime(date_to, '%Y-%m-%d') + timedelta(days=1)
posts_query = posts_query.filter(Post.published_at < date_to)
except ValueError:
pass
# Sorting (unknown sort_by values fall back to newest first)
if sort_by == 'published_at':
order_by = Post.published_at.desc() if order == 'desc' else Post.published_at.asc()
elif sort_by == 'view_count':
order_by = Post.view_count.desc() if order == 'desc' else Post.view_count.asc()
elif sort_by == 'like_count':
order_by = Post.like_count.desc() if order == 'desc' else Post.like_count.asc()
elif sort_by == 'title':
order_by = Post.title.desc() if order == 'desc' else Post.title.asc()
else:
order_by = Post.published_at.desc()
posts_query = posts_query.order_by(order_by)
# Eager-load related data to avoid N+1 queries
posts_query = posts_query.options(
joinedload(Post.author),
joinedload(Post.category),
selectinload(Post.tags)
)
# Paginate (max_per_page caps client-supplied per_page values)
pagination = posts_query.paginate(page=page, per_page=per_page, error_out=False, max_per_page=100)
return jsonify({
'posts': [post.to_dict() for post in pagination.items],
'pagination': {
'page': pagination.page,
'pages': pagination.pages,
'per_page': pagination.per_page,
'total': pagination.total,
'has_next': pagination.has_next,
'has_prev': pagination.has_prev
}
})
@app.route('/api/dashboard-stats')
def dashboard_stats():
"""Dashboard statistics."""
# Basic statistics
user_stats = DatabaseService.get_user_stats()
category_stats = DatabaseService.get_category_stats()
trending_tags = DatabaseService.get_trending_tags(10)
monthly_stats = DatabaseService.get_monthly_post_stats()
# Recent activity
recent_posts = Post.query.filter_by(is_published=True)\
.order_by(Post.published_at.desc())\
.limit(5).all()
recent_comments = Comment.query.filter_by(is_approved=True)\
.order_by(Comment.created_at.desc())\
.limit(5).all()
# Popular posts
popular_posts = DatabaseService.get_popular_posts(5)
return jsonify({
'user_stats': user_stats,
'category_stats': category_stats,
'trending_tags': trending_tags,
'monthly_stats': monthly_stats,
'recent_posts': [post.to_dict() for post in recent_posts],
'recent_comments': [
{
'id': comment.id,
'content': comment.content[:100] + '...' if len(comment.content) > 100 else comment.content,
'author': comment.author.username,
'post_title': comment.post.title,
'created_at': comment.created_at.isoformat()
}
for comment in recent_comments
],
'popular_posts': [post.to_dict() for post in popular_posts]
})
# Raw SQL query examples
@app.route('/api/custom-reports')
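def custom_reports():
"""Custom reports built with raw SQL."""
# User engagement report.
# Posts and comments are aggregated in separate subqueries: joining both
# tables directly multiplies the rows and inflates SUM(view_count).
# The literal 1 for booleans assumes MySQL or SQLite.
user_engagement_sql = text("""
    SELECT
        u.id,
        u.username,
        u.email,
        ps.posts_count,
        COALESCE(cs.comments_count, 0) AS comments_count,
        ps.total_views,
        ps.avg_views_per_post,
        ps.last_post_date
    FROM user u
    JOIN (
        SELECT user_id,
               COUNT(*) AS posts_count,
               SUM(view_count) AS total_views,
               AVG(view_count) AS avg_views_per_post,
               MAX(published_at) AS last_post_date
        FROM post
        WHERE is_published = 1
        GROUP BY user_id
    ) ps ON ps.user_id = u.id
    LEFT JOIN (
        SELECT user_id, COUNT(*) AS comments_count
        FROM comment
        WHERE is_approved = 1
        GROUP BY user_id
    ) cs ON cs.user_id = u.id
    WHERE u.is_active = 1
    ORDER BY ps.total_views DESC
    LIMIT 20
""")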
result = db.session.execute(user_engagement_sql)
user_engagement = [
{
'id': row.id,
'username': row.username,
'email': row.email,
'posts_count': row.posts_count,
'comments_count': row.comments_count,
'total_views': row.total_views or 0,
'avg_views_per_post': float(row.avg_views_per_post) if row.avg_views_per_post else 0,
'last_post_date': row.last_post_date.isoformat() if row.last_post_date else None
}
for row in result
]
# Content performance report
content_performance_sql = text("""
SELECT
p.id,
p.title,
p.published_at,
p.view_count,
p.like_count,
p.comment_count,
u.username as author,
c.name as category,
(p.view_count + p.like_count * 2 + p.comment_count * 3) as engagement_score
FROM post p
JOIN user u ON p.user_id = u.id
LEFT JOIN category c ON p.category_id = c.id
WHERE p.is_published = 1
ORDER BY engagement_score DESC
LIMIT 20
""")
result = db.session.execute(content_performance_sql)
content_performance = [
{
'id': row.id,
'title': row.title,
'published_at': row.published_at.isoformat() if row.published_at else None,
'view_count': row.view_count,
'like_count': row.like_count,
'comment_count': row.comment_count,
'author': row.author,
'category': row.category,
'engagement_score': row.engagement_score
}
for row in result
]
return jsonify({
'user_engagement': user_engagement,
'content_performance': content_performance
})
# Database performance optimization examples
class QueryOptimizer:
"""Query optimization helpers."""
@staticmethod
def get_posts_with_eager_loading(page=1, per_page=10):
"""Query posts with eager loading of related objects."""
return Post.query.filter_by(is_published=True)\
.options(
joinedload(Post.author),
joinedload(Post.category),
selectinload(Post.tags),
selectinload(Post.comments).joinedload(Comment.author)
)\
.order_by(Post.published_at.desc())\
.paginate(page=page, per_page=per_page, error_out=False)
@staticmethod
def get_user_posts_count_efficient():
"""Count published posts per user in a single grouped query."""
return db.session.query(
User.id,
User.username,
func.count(Post.id).label('posts_count')
).outerjoin(Post, and_(Post.user_id == User.id, Post.is_published == True))\
.group_by(User.id, User.username)\
.order_by(func.count(Post.id).desc()).all()
    @staticmethod
    def bulk_update_view_counts(post_views):
        """Update the view counts of many posts in a single statement."""
        # post_views is a dict of the form {post_id: view_count}
        if not post_views:
            return
        # The dict form of case() compares Post.id against each key; unlike the
        # list-of-tuples form it is accepted by both SQLAlchemy 1.x and 2.x.
        stmt = case(post_views, value=Post.id, else_=Post.view_count)
        db.session.query(Post)\
            .filter(Post.id.in_(list(post_views)))\
            .update({Post.view_count: stmt}, synchronize_session=False)
        db.session.commit()
# Database setup and maintenance commands
@app.cli.command()
def init_db():
"""Initialize the database."""
db.create_all()
print('Database initialized.')
@app.cli.command()
def seed_db():
"""Fill the database with test data."""
from faker import Faker
fake = Faker()
# Create roles
admin_role = Role(name='admin', description='Administrator')
user_role = Role(name='user', description='Regular user')
db.session.add_all([admin_role, user_role])
# Create categories
categories = [
Category(name='Technology', slug='technology', description='Tech related posts'),
Category(name='Science', slug='science', description='Science articles'),
Category(name='Business', slug='business', description='Business insights')
]
db.session.add_all(categories)
# Create tags
tags = [
Tag(name='Python', slug='python'),
Tag(name='Flask', slug='flask'),
Tag(name='Web Development', slug='web-development'),
Tag(name='AI', slug='ai'),
Tag(name='Machine Learning', slug='machine-learning')
]
db.session.add_all(tags)
db.session.commit()
# Create users
for i in range(10):
user = User(
username=fake.user_name(),
email=fake.email(),
first_name=fake.first_name(),
last_name=fake.last_name(),
bio=fake.text(max_nb_chars=200),
is_active=True,
is_verified=fake.boolean(chance_of_getting_true=80)
)
user.set_password('password123')
user.add_role(user_role)
db.session.add(user)
db.session.commit()
# Create posts
users = User.query.all()
for i in range(50):
post = Post(
title=fake.sentence(nb_words=6),
content=fake.text(max_nb_chars=2000),
is_published=fake.boolean(chance_of_getting_true=80),
view_count=fake.random_int(min=0, max=1000),
like_count=fake.random_int(min=0, max=100),
user_id=fake.random_element(users).id,
category_id=fake.random_element(categories).id
)
if post.is_published:
post.published_at = fake.date_time_between(start_date='-1y', end_date='now')
# Attach a few random tags (named chosen_tags to avoid confusion with the post_tags association table)
chosen_tags = fake.random_elements(tags, length=fake.random_int(min=1, max=3), unique=True)
for tag in chosen_tags:
post.add_tag(tag)
db.session.add(post)
db.session.commit()
print('Database seeded with test data.')
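The CASE-based bulk update in `QueryOptimizer.bulk_update_view_counts` can be tried without the ORM. The sketch below reproduces the idea with the standard-library `sqlite3` module, assuming a simplified `post` table that has only `id` and `view_count` columns. The table layout and the sample rows are assumptions for illustration, not the schema defined in models.py.

```python
import sqlite3


def bulk_update_view_counts(conn, post_views):
    """Set view_count for many posts with one UPDATE ... CASE statement.

    post_views maps post_id -> new view_count. Returns the number of rows updated.
    """
    if not post_views:
        return 0
    ids = list(post_views)
    # One "WHEN ? THEN ?" pair per post, plus one placeholder per id for IN (...).
    when_clauses = " ".join("WHEN ? THEN ?" for _ in ids)
    placeholders = ", ".join("?" for _ in ids)
    sql = (
        f"UPDATE post SET view_count = CASE id {when_clauses} "
        f"ELSE view_count END WHERE id IN ({placeholders})"
    )
    params = []
    for post_id in ids:
        params.extend((post_id, post_views[post_id]))
    params.extend(ids)
    cursor = conn.execute(sql, params)
    conn.commit()
    return cursor.rowcount


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE post (id INTEGER PRIMARY KEY, view_count INTEGER NOT NULL DEFAULT 0)"
)
conn.executemany(
    "INSERT INTO post (id, view_count) VALUES (?, ?)", [(1, 10), (2, 20), (3, 30)]
)
updated = bulk_update_view_counts(conn, {1: 100, 3: 300})
rows = conn.execute("SELECT id, view_count FROM post ORDER BY id").fetchall()
print(updated, rows)
```

Only values travel as bound parameters; the statement text is built from placeholder counts, never from user input. For very large batches the database's limit on bound parameters applies, so the dictionary would have to be processed in chunks.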
Installation and configuration:
pip install Flask-Login
# app.py
from flask_login import LoginManager, login_user, logout_user, login_required, current_user
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.login_message = 'Please log in to access this page.'
@login_manager.user_loader
def load_user(user_id):
return User.query.get(int(user_id))
@app.route('/login', methods=['GET', 'POST'])
def login():
if request.method == 'POST':
username = request.form['username']
password = request.form['password']
user = User.query.filter_by(username=username).first()
if user and user.check_password(password):
login_user(user)
flash('Login successful!', 'success')
return redirect(url_for('dashboard'))
else:
flash('Invalid username or password', 'error')
return render_template('login.html')
@app.route('/logout')
@login_required
def logout():
logout_user()
flash('You have been logged out.', 'info')
return redirect(url_for('index'))
@app.route('/dashboard')
@login_required
def dashboard():
return render_template('dashboard.html', user=current_user)
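The login view above relies on `user.check_password`, whose implementation is not shown in this section. As a rough illustration of what a salted password check involves, here is a standard-library sketch based on PBKDF2. It is not Werkzeug's implementation: the function names, the `iterations$salt$digest` storage format, and the iteration count are all assumptions chosen for the example.

```python
import hashlib
import hmac
import secrets

ITERATIONS = 200_000  # assumed work factor; tune for your hardware


def hash_password(password: str) -> str:
    """Return 'iterations$salt_hex$digest_hex' for the given password."""
    salt = secrets.token_bytes(16)  # fresh random salt per password
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, ITERATIONS)
    return f"{ITERATIONS}${salt.hex()}${digest.hex()}"


def verify_password(stored: str, candidate: str) -> bool:
    """Check a candidate password against a stored hash string."""
    try:
        iterations_text, salt_hex, digest_hex = stored.split("$")
        iterations = int(iterations_text)
        salt = bytes.fromhex(salt_hex)
        expected = bytes.fromhex(digest_hex)
    except ValueError:
        return False  # malformed stored value
    actual = hashlib.pbkdf2_hmac("sha256", candidate.encode("utf-8"), salt, iterations)
    # Constant-time comparison avoids leaking how many bytes matched.
    return hmac.compare_digest(actual, expected)


stored = hash_password("correct horse battery staple")
print(verify_password(stored, "correct horse battery staple"))
print(verify_password(stored, "wrong password"))
```

In an actual Flask project, a maintained helper such as `werkzeug.security.generate_password_hash` and `check_password_hash` (used later in auth.py) is the safer choice than hand-rolled hashing.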
# api/users.py
from flask import Blueprint, jsonify, request
from models import db, User
api_bp = Blueprint('api', __name__, url_prefix='/api/v1')
@api_bp.route('/users', methods=['GET'])
def get_users():
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
users = User.query.paginate(
page=page, per_page=per_page, error_out=False
)
return jsonify({
'users': [user.to_dict() for user in users.items],
'total': users.total,
'pages': users.pages,
'current_page': page
})
@api_bp.route('/users/<int:user_id>', methods=['GET'])
def get_user(user_id):
user = User.query.get_or_404(user_id)
return jsonify(user.to_dict())
@api_bp.route('/users', methods=['POST'])
def create_user():
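# silent=True returns None for a missing or malformed JSON body, so the check below answers with a 400
data = request.get_json(silent=True)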
if not data or not data.get('username') or not data.get('email'):
return jsonify({'error': 'Missing required fields'}), 400
# Check whether the user already exists
if User.query.filter_by(username=data['username']).first():
return jsonify({'error': 'Username already exists'}), 400
if User.query.filter_by(email=data['email']).first():
return jsonify({'error': 'Email already exists'}), 400
user = User(
username=data['username'],
email=data['email'],
first_name=data.get('first_name'),
last_name=data.get('last_name')
)
if data.get('password'):
user.set_password(data['password'])
db.session.add(user)
db.session.commit()
return jsonify(user.to_dict()), 201
# Register the blueprint
app.register_blueprint(api_bp)
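`create_user` above only checks that `username` and `email` are present. A small validation helper can return one error per field instead. The sketch below is a standalone illustration: `validate_user_payload` and the error wording are hypothetical, and the username and email patterns are borrowed from the registration view in auth.py later in this tutorial.

```python
import re

# Patterns borrowed from the registration view; fullmatch() is used below so a
# trailing newline cannot slip through.
USERNAME_RE = re.compile(r"[a-zA-Z0-9_]{3,}")
EMAIL_RE = re.compile(r"[^\s@]+@[^\s@]+\.[^\s@]+")


def validate_user_payload(data):
    """Return a dict mapping field name -> error message (empty when valid)."""
    if not isinstance(data, dict):
        return {"body": "Expected a JSON object"}
    errors = {}
    username = data.get("username")
    email = data.get("email")
    if not isinstance(username, str) or not USERNAME_RE.fullmatch(username):
        errors["username"] = "At least 3 characters: letters, digits, underscores"
    if not isinstance(email, str) or not EMAIL_RE.fullmatch(email):
        errors["email"] = "Not a valid email address"
    return errors


print(validate_user_payload({"username": "alice_1", "email": "alice@example.com"}))
print(validate_user_payload({"username": "ab", "email": "nope"}))
```

Inside the view, a non-empty result could be returned as `jsonify({'errors': errors}), 400` before any database query runs.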
import os
from werkzeug.utils import secure_filename
from PIL import Image
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB
def allowed_file(filename):
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.route('/upload', methods=['GET', 'POST'])
def upload_file():
if request.method == 'POST':
if 'file' not in request.files:
flash('No file selected', 'error')
return redirect(request.url)
file = request.files['file']
if file.filename == '':
flash('No file selected', 'error')
return redirect(request.url)
if file and allowed_file(file.filename):
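filename = secure_filename(file.filename)
# Make sure the upload directory exists before saving.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(filepath)
# Image processing (optional): shrink large images in place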
if filename.lower().endswith(('png', 'jpg', 'jpeg')):
with Image.open(filepath) as img:
img.thumbnail((800, 600))
img.save(filepath)
flash('File uploaded successfully!', 'success')
return redirect(url_for('upload_file'))
return render_template('upload.html')
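The upload view saves files under their sanitized original name, so two uploads called `photo.jpg` overwrite each other. One common remedy is to store files under a random name and keep only the validated extension. The sketch below assumes the same extension whitelist as above; `file_extension` and `stored_filename` are hypothetical helper names.

```python
import uuid

ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg", "gif"}


def file_extension(filename: str) -> str:
    """Return the lowercase extension without the dot, or '' if there is none."""
    if "." not in filename:
        return ""
    return filename.rsplit(".", 1)[1].lower()


def stored_filename(original_name: str):
    """Return a random, collision-resistant file name, or None if not allowed."""
    ext = file_extension(original_name)
    if ext not in ALLOWED_EXTENSIONS:
        return None
    # uuid4().hex is 32 hex characters; the original name is not reused at all.
    return f"{uuid.uuid4().hex}.{ext}"


print(stored_filename("Holiday Photo.JPG"))
print(stored_filename("script.py"))
```

This checks the extension only, not the file contents. The original name can be stored in the database next to the generated one if it needs to be shown to users.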
# auth.py - a complete authentication system
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, session
from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required, current_user
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from werkzeug.security import generate_password_hash, check_password_hash
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
from datetime import datetime, timedelta
import secrets
import re
import hashlib
import pyotp
import qrcode
from io import BytesIO
import base64
app = Flask(__name__)
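# Placeholder values: load real secrets from the environment or a config file, never from source code
app.config['SECRET_KEY'] = 'your-secret-key-here'
app.config['SECURITY_PASSWORD_SALT'] = 'your-password-salt'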
# Initialize extensions
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'  # the views below are registered on app, not on an 'auth' blueprint
login_manager.login_message = 'Please log in to access this page.'
login_manager.login_message_category = 'info'
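# Rate limiting
# key_func and app are passed by keyword: Flask-Limiter 3.x takes key_func as
# its first positional argument, so Limiter(app, key_func=...) fails there.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    default_limits=["200 per day", "50 per hour"]
)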
# User model (simplified; a real application should use a database)
class User(UserMixin):
    # UserMixin defines is_active as a read-only property; overriding it with
    # a plain class attribute lets __init__ assign a per-user value.
    is_active = True

    def __init__(self, id, username, email, password_hash, is_active=True,
                 is_verified=False, two_factor_enabled=False, two_factor_secret=None):
        self.id = id
        self.username = username
        self.email = email
        self.password_hash = password_hash
        self.is_active = is_active
        self.is_verified = is_verified
        self.two_factor_enabled = two_factor_enabled
        self.two_factor_secret = two_factor_secret
        self.failed_login_attempts = 0
        self.locked_until = None
        self.last_login = None
        self.created_at = datetime.utcnow()
def check_password(self, password):
return check_password_hash(self.password_hash, password)
def set_password(self, password):
self.password_hash = generate_password_hash(password)
def is_account_locked(self):
if self.locked_until and datetime.utcnow() < self.locked_until:
return True
return False
def lock_account(self, duration_minutes=30):
self.locked_until = datetime.utcnow() + timedelta(minutes=duration_minutes)
def unlock_account(self):
self.failed_login_attempts = 0
self.locked_until = None
def generate_2fa_secret(self):
self.two_factor_secret = pyotp.random_base32()
return self.two_factor_secret
def get_2fa_qr_code(self):
if not self.two_factor_secret:
self.generate_2fa_secret()
totp_uri = pyotp.totp.TOTP(self.two_factor_secret).provisioning_uri(
name=self.email,
issuer_name="Flask App"
)
qr = qrcode.QRCode(version=1, box_size=10, border=5)
qr.add_data(totp_uri)
qr.make(fit=True)
img = qr.make_image(fill_color="black", back_color="white")
buffer = BytesIO()
img.save(buffer, format='PNG')
buffer.seek(0)
return base64.b64encode(buffer.getvalue()).decode()
def verify_2fa_token(self, token):
if not self.two_factor_secret:
return False
totp = pyotp.TOTP(self.two_factor_secret)
return totp.verify(token, valid_window=1)
# User store (a real project should use a database)
users_db = {}
@login_manager.user_loader
def load_user(user_id):
return users_db.get(int(user_id))
# Password strength validation
class PasswordValidator:
@staticmethod
def validate_password(password):
errors = []
if len(password) < 8:
errors.append("Password must be at least 8 characters long")
if not re.search(r'[A-Z]', password):
errors.append("Password must contain at least one uppercase letter")
if not re.search(r'[a-z]', password):
errors.append("Password must contain at least one lowercase letter")
if not re.search(r'\d', password):
errors.append("Password must contain at least one digit")
if not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
errors.append("Password must contain at least one special character")
# Reject common passwords
common_passwords = [
'password', '123456', 'password123', 'admin', 'qwerty',
'letmein', 'welcome', 'monkey', '1234567890'
]
if password.lower() in common_passwords:
errors.append("Password is too common")
return errors
# Email verification
class EmailVerification:
@staticmethod
def generate_token(email):
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])
return serializer.dumps(email, salt=app.config['SECURITY_PASSWORD_SALT'])
@staticmethod
def verify_token(token, expiration=3600):
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])
try:
email = serializer.loads(
token,
salt=app.config['SECURITY_PASSWORD_SALT'],
max_age=expiration
)
return email
except (SignatureExpired, BadSignature):
return None
# Password reset
class PasswordReset:
@staticmethod
def generate_reset_token(user_id):
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])
return serializer.dumps({'user_id': user_id, 'timestamp': datetime.utcnow().isoformat()})
@staticmethod
def verify_reset_token(token, expiration=3600):
serializer = URLSafeTimedSerializer(app.config['SECRET_KEY'])
try:
data = serializer.loads(token, max_age=expiration)
return data.get('user_id')
except (SignatureExpired, BadSignature):
return None
# Session security
class SessionSecurity:
@staticmethod
def generate_csrf_token():
if 'csrf_token' not in session:
session['csrf_token'] = secrets.token_hex(16)
return session['csrf_token']
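    @staticmethod
    def validate_csrf_token(token):
        # Compare in constant time and always return a bool.
        expected = session.get('csrf_token')
        if not token or not expected:
            return False
        return secrets.compare_digest(str(expected), str(token))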
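    @staticmethod
    def regenerate_session():
        """Refresh the session after login to limit session-fixation risk."""
        # Flask's default cookie-based session object has no regenerate()
        # method and no server-side session ID. Marking it permanent and
        # modified makes Flask issue a fresh signed cookie; with a server-side
        # session extension, ID rotation has to go through that extension.
        session.permanent = True
        session.modified = True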
# Authentication routes
@app.route('/register', methods=['GET', 'POST'])
@limiter.limit("5 per minute")
def register():
if request.method == 'POST':
username = request.form.get('username', '').strip()
email = request.form.get('email', '').strip().lower()
password = request.form.get('password', '')
confirm_password = request.form.get('confirm_password', '')
# Validate the input
errors = []
if not username or len(username) < 3:
errors.append("Username must be at least 3 characters long")
if not re.match(r'^[a-zA-Z0-9_]+$', username):
errors.append("Username can only contain letters, numbers, and underscores")
if not email or not re.match(r'^[^\s@]+@[^\s@]+\.[^\s@]+$', email):
errors.append("Please enter a valid email address")
if password != confirm_password:
errors.append("Passwords do not match")
password_errors = PasswordValidator.validate_password(password)
errors.extend(password_errors)
# Check whether the username or email is already taken
for user in users_db.values():
if user.username == username:
errors.append("Username already exists")
break
if user.email == email:
errors.append("Email already registered")
break
if errors:
for error in errors:
flash(error, 'error')
return render_template('auth/register.html')
# Create the user
user_id = len(users_db) + 1
user = User(
id=user_id,
username=username,
email=email,
password_hash=generate_password_hash(password)
)
users_db[user_id] = user
# Send the verification email
verification_token = EmailVerification.generate_token(email)
# A real application would send an email here; this example only prints the link
verification_url = url_for('verify_email', token=verification_token, _external=True)
print(f"Verification URL: {verification_url}")
flash('Registration successful! Please check your email to verify your account.', 'success')
return redirect(url_for('login'))
return render_template('auth/register.html')
@app.route('/login', methods=['GET', 'POST'])
@limiter.limit("10 per minute")
def login():
if request.method == 'POST':
username_or_email = request.form.get('username_or_email', '').strip()
password = request.form.get('password', '')
remember_me = request.form.get('remember_me') == 'on'
two_factor_token = request.form.get('two_factor_token', '').strip()
if not username_or_email or not password:
flash('Please enter both username/email and password', 'error')
return render_template('auth/login.html')
# Look up the user
user = None
for u in users_db.values():
if u.username == username_or_email or u.email == username_or_email:
user = u
break
if not user:
flash('Invalid credentials', 'error')
return render_template('auth/login.html')
# Check whether the account is locked
if user.is_account_locked():
flash('Account is temporarily locked due to too many failed login attempts', 'error')
return render_template('auth/login.html')
# Verify the password
if not user.check_password(password):
user.failed_login_attempts += 1
if user.failed_login_attempts >= 5:
user.lock_account()
flash('Account locked due to too many failed login attempts', 'error')
else:
flash('Invalid credentials', 'error')
return render_template('auth/login.html')
# Check the account status
if not user.is_active:
flash('Account is deactivated', 'error')
return render_template('auth/login.html')
# Two-factor authentication
if user.two_factor_enabled:
if not two_factor_token:
session['pending_user_id'] = user.id
return render_template('auth/two_factor.html')
if not user.verify_2fa_token(two_factor_token):
flash('Invalid two-factor authentication code', 'error')
return render_template('auth/two_factor.html')
# Login succeeded
user.unlock_account()
user.last_login = datetime.utcnow()
login_user(user, remember=remember_me)
SessionSecurity.regenerate_session()
# Record the login (printed here; a real application would persist it)
login_log = {
'user_id': user.id,
'ip_address': request.remote_addr,
'user_agent': request.headers.get('User-Agent'),
'timestamp': datetime.utcnow(),
'success': True
}
print(f"Login log: {login_log}")
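next_page = request.args.get('next')
# Only follow local paths; '//host' and '/\host' would send the browser to another site.
if next_page and next_page.startswith('/') and not next_page.startswith(('//', '/\\')):
return redirect(next_page)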
flash(f'Welcome back, {user.username}!', 'success')
return redirect(url_for('dashboard'))
return render_template('auth/login.html')
@app.route('/two-factor', methods=['POST'])
def verify_two_factor():
user_id = session.get('pending_user_id')
if not user_id:
return redirect(url_for('login'))
user = users_db.get(user_id)
if not user:
return redirect(url_for('login'))
token = request.form.get('two_factor_token', '').strip()
if user.verify_2fa_token(token):
session.pop('pending_user_id', None)
login_user(user)
SessionSecurity.regenerate_session()
flash(f'Welcome back, {user.username}!', 'success')
return redirect(url_for('dashboard'))
else:
flash('Invalid two-factor authentication code', 'error')
return render_template('auth/two_factor.html')
@app.route('/logout')
@login_required
def logout():
username = current_user.username
logout_user()
session.clear()
flash(f'Goodbye, {username}!', 'info')
return redirect(url_for('index'))
@app.route('/verify-email/<token>')
def verify_email(token):
email = EmailVerification.verify_token(token)
if not email:
flash('Invalid or expired verification link', 'error')
return redirect(url_for('login'))
# Find the user and mark the email as verified
for user in users_db.values():
if user.email == email:
user.is_verified = True
flash('Email verified successfully! You can now log in.', 'success')
return redirect(url_for('login'))
flash('User not found', 'error')
return redirect(url_for('login'))
@app.route('/forgot-password', methods=['GET', 'POST'])
@limiter.limit("3 per minute")
def forgot_password():
if request.method == 'POST':
email = request.form.get('email', '').strip().lower()
if not email:
flash('Please enter your email address', 'error')
return render_template('auth/forgot_password.html')
# Look up the user
user = None
for u in users_db.values():
if u.email == email:
user = u
break
if user:
reset_token = PasswordReset.generate_reset_token(user.id)
reset_url = url_for('reset_password', token=reset_token, _external=True)
# A real application would send an email here
print(f"Password reset URL: {reset_url}")
# Show the same message whether or not the account exists, so the form does not reveal registered emails
flash('If an account with that email exists, a password reset link has been sent.', 'info')
return redirect(url_for('login'))
return render_template('auth/forgot_password.html')
@app.route('/reset-password/<token>', methods=['GET', 'POST'])
def reset_password(token):
user_id = PasswordReset.verify_reset_token(token)
if not user_id:
flash('Invalid or expired reset link', 'error')
return redirect(url_for('forgot_password'))
user = users_db.get(user_id)
if not user:
flash('User not found', 'error')
return redirect(url_for('forgot_password'))
if request.method == 'POST':
password = request.form.get('password', '')
confirm_password = request.form.get('confirm_password', '')
if password != confirm_password:
flash('Passwords do not match', 'error')
return render_template('auth/reset_password.html', token=token)
password_errors = PasswordValidator.validate_password(password)
if password_errors:
for error in password_errors:
flash(error, 'error')
return render_template('auth/reset_password.html', token=token)
user.set_password(password)
user.unlock_account()  # Clear any lockout left over from failed attempts
flash('Password reset successfully! You can now log in with your new password.', 'success')
return redirect(url_for('login'))
return render_template('auth/reset_password.html', token=token)
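The login view redirects to the `next` query parameter when it looks like a local path. Because that value comes from the client, it is worth validating in one place. The following standalone sketch shows one way to do that with the standard library; `is_safe_redirect_target` is a hypothetical helper name, and the rule set (local absolute paths only) is an assumption that may be stricter than some applications need.

```python
from urllib.parse import urlparse


def is_safe_redirect_target(target) -> bool:
    """Accept only local absolute paths such as '/dashboard?tab=1'."""
    if not target or not target.startswith("/"):
        return False
    # '//host' is protocol-relative and '/\host' is treated like it by some browsers.
    if target.startswith("//") or target.startswith("/\\"):
        return False
    # Control characters could be used to smuggle a different URL or header.
    if any(ch in target for ch in "\r\n\t"):
        return False
    parsed = urlparse(target)
    return parsed.scheme == "" and parsed.netloc == ""


for candidate in ["/dashboard?tab=1", "//evil.example", "https://evil.example", "/\\evil.example"]:
    print(candidate, is_safe_redirect_target(candidate))
```

In the view, the check would wrap the redirect, for example `if is_safe_redirect_target(next_page): return redirect(next_page)`, with the dashboard as the fallback.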
Two-Factor Authentication and Security Settings
# security_settings.py
@app.route('/security-settings')
@login_required
def security_settings():
return render_template('auth/security_settings.html')
@app.route('/enable-2fa', methods=['GET', 'POST'])
@login_required
def enable_2fa():
if current_user.two_factor_enabled:
flash('Two-factor authentication is already enabled', 'info')
return redirect(url_for('security_settings'))
if request.method == 'POST':
token = request.form.get('token', '').strip()
if not current_user.two_factor_secret:
flash('Two-factor setup not initiated', 'error')
return redirect(url_for('enable_2fa'))
if current_user.verify_2fa_token(token):
current_user.two_factor_enabled = True
flash('Two-factor authentication enabled successfully!', 'success')
return redirect(url_for('security_settings'))
else:
flash('Invalid verification code', 'error')
# Generate the QR code
if not current_user.two_factor_secret:
current_user.generate_2fa_secret()
qr_code = current_user.get_2fa_qr_code()
return render_template('auth/enable_2fa.html',
qr_code=qr_code,
secret=current_user.two_factor_secret)
@app.route('/disable-2fa', methods=['POST'])
@login_required
def disable_2fa():
password = request.form.get('password', '')
if not current_user.check_password(password):
flash('Invalid password', 'error')
return redirect(url_for('security_settings'))
current_user.two_factor_enabled = False
current_user.two_factor_secret = None
flash('Two-factor authentication disabled', 'info')
return redirect(url_for('security_settings'))
@app.route('/change-password', methods=['GET', 'POST'])
@login_required
def change_password():
if request.method == 'POST':
current_password = request.form.get('current_password', '')
new_password = request.form.get('new_password', '')
confirm_password = request.form.get('confirm_password', '')
if not current_user.check_password(current_password):
flash('Current password is incorrect', 'error')
return render_template('auth/change_password.html')
if new_password != confirm_password:
flash('New passwords do not match', 'error')
return render_template('auth/change_password.html')
password_errors = PasswordValidator.validate_password(new_password)
if password_errors:
for error in password_errors:
flash(error, 'error')
return render_template('auth/change_password.html')
current_user.set_password(new_password)
flash('Password changed successfully!', 'success')
return redirect(url_for('security_settings'))
return render_template('auth/change_password.html')
# Active session management
@app.route('/active-sessions')
@login_required
def active_sessions():
# A real application would load the user's active sessions from the database; these are sample values
sessions = [
{
'id': 'session_1',
'ip_address': '192.168.1.100',
'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
'location': 'New York, US',
'last_activity': datetime.utcnow() - timedelta(minutes=5),
'is_current': True
},
{
'id': 'session_2',
'ip_address': '192.168.1.101',
'user_agent': 'Mozilla/5.0 (iPhone; CPU iPhone OS 14_0 like Mac OS X)',
'location': 'New York, US',
'last_activity': datetime.utcnow() - timedelta(hours=2),
'is_current': False
}
]
return render_template('auth/active_sessions.html', sessions=sessions)

@app.route('/revoke-session/<session_id>', methods=['POST'])
@login_required
def revoke_session(session_id):
    # Placeholder; a real application would delete the given session from the database here
    flash('Session revoked successfully', 'success')
    return redirect(url_for('active_sessions'))

# Login history
@app.route('/login-history')
@login_required
def login_history():
    # Placeholder data; a real application would load the user's login history from the database
    history = [
        {
            'timestamp': datetime.utcnow() - timedelta(minutes=5),
            'ip_address': '192.168.1.100',
            'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)',
            'location': 'New York, US',
            'success': True
        },
        {
            'timestamp': datetime.utcnow() - timedelta(hours=2),
            'ip_address': '192.168.1.101',
            'user_agent': 'Mozilla/5.0 (iPhone; CPU iPhone OS 14_0)',
            'location': 'New York, US',
            'success': True
        },
        {
            'timestamp': datetime.utcnow() - timedelta(days=1),
            'ip_address': '10.0.0.1',
            'user_agent': 'Unknown',
            'location': 'Unknown',
            'success': False
        }
    ]
    return render_template('auth/login_history.html', history=history)
from flask import Flask, jsonify, request
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///api.db'
db = SQLAlchemy(app)

class Task(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=False)
    description = db.Column(db.Text)
    completed = db.Column(db.Boolean, default=False)

    def to_dict(self):
        return {
            'id': self.id,
            'title': self.title,
            'description': self.description,
            'completed': self.completed
        }

# GET /api/tasks - list all tasks
@app.route('/api/tasks', methods=['GET'])
def get_tasks():
    tasks = Task.query.all()
    return jsonify([task.to_dict() for task in tasks])

# POST /api/tasks - create a new task
@app.route('/api/tasks', methods=['POST'])
def create_task():
    # silent=True returns None for a missing or malformed JSON body instead of raising
    data = request.get_json(silent=True)
    if not data or 'title' not in data:
        return jsonify({'error': 'Title is required'}), 400
    task = Task(
        title=data['title'],
        description=data.get('description', ''),
        completed=data.get('completed', False)
    )
    db.session.add(task)
    db.session.commit()
    return jsonify(task.to_dict()), 201

# PUT /api/tasks/<task_id> - update a task
@app.route('/api/tasks/<int:task_id>', methods=['PUT'])
def update_task(task_id):
    task = Task.query.get_or_404(task_id)
    data = request.get_json(silent=True)
    if data is None:
        return jsonify({'error': 'Request body must be JSON'}), 400
    # Fields missing from the payload keep their current values
    task.title = data.get('title', task.title)
    task.description = data.get('description', task.description)
    task.completed = data.get('completed', task.completed)
    db.session.commit()
    return jsonify(task.to_dict())

# DELETE /api/tasks/<task_id> - delete a task
@app.route('/api/tasks/<int:task_id>', methods=['DELETE'])
def delete_task(task_id):
    task = Task.query.get_or_404(task_id)
    db.session.delete(task)
    db.session.commit()
    return '', 204
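
The create and update handlers above only check that a title is present. As a rough sketch of stricter validation, the helper below checks field types and the 100-character title limit from the `Task` model. The name `validate_task_payload` and its error messages are illustrative assumptions, not part of Flask or of the original example.

```python
def validate_task_payload(data, require_title=True):
    """Return a list of problems; an empty list means the payload is acceptable."""
    if not isinstance(data, dict):
        return ['Request body must be a JSON object']
    errors = []
    title = data.get('title')
    if title is None:
        # PUT may omit the title, POST may not
        if require_title:
            errors.append('Title is required')
    elif not isinstance(title, str) or not title.strip():
        errors.append('Title must be a non-empty string')
    elif len(title) > 100:  # mirrors db.String(100) on Task.title
        errors.append('Title must be at most 100 characters')
    description = data.get('description')
    if description is not None and not isinstance(description, str):
        errors.append('Description must be a string')
    completed = data.get('completed')
    if completed is not None and not isinstance(completed, bool):
        errors.append('Completed must be true or false')
    return errors
```

A view could call it right after reading the JSON body and return `jsonify({'errors': errors}), 400` when the list is non-empty, passing `require_title=False` for updates.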
Let's build a complete blog system with user registration, login, post publishing, and comments.
flask-blog/
├── app.py # Application entry point
├── config.py # Configuration
├── models.py # Data models
├── forms.py # Form definitions
├── requirements.txt # Dependencies
├── static/ # Static files
│ ├── css/
│ ├── js/
│ └── uploads/
├── templates/ # Templates
│ ├── base.html
│ ├── index.html
│ ├── login.html
│ ├── register.html
│ ├── post.html
│ └── create_post.html
└── migrations/ # Database migration files
1. User management system:
# forms.py
from flask_wtf import FlaskForm
from wtforms import StringField, PasswordField, TextAreaField, BooleanField, SubmitField
from wtforms.validators import DataRequired, Length, Email, EqualTo, ValidationError
from models import User

class LoginForm(FlaskForm):
    username = StringField('Username', validators=[DataRequired()])
    password = PasswordField('Password', validators=[DataRequired()])
    remember_me = BooleanField('Remember Me')
    submit = SubmitField('Sign In')

class RegistrationForm(FlaskForm):
    username = StringField('Username', validators=[DataRequired(), Length(min=4, max=20)])
    email = StringField('Email', validators=[DataRequired(), Email()])
    password = PasswordField('Password', validators=[DataRequired(), Length(min=8)])
    password2 = PasswordField('Repeat Password', validators=[
        DataRequired(), EqualTo('password')])
    submit = SubmitField('Register')

    # WTForms calls validate_<field_name> methods automatically during validation
    def validate_username(self, username):
        user = User.query.filter_by(username=username.data).first()
        if user is not None:
            raise ValidationError('Please use a different username.')

    def validate_email(self, email):
        user = User.query.filter_by(email=email.data).first()
        if user is not None:
            raise ValidationError('Please use a different email address.')

class PostForm(FlaskForm):
    title = StringField('Title', validators=[DataRequired(), Length(min=1, max=200)])
    content = TextAreaField('Content', validators=[DataRequired()],
                            render_kw={"rows": 10})
    submit = SubmitField('Publish')
2. Routes and view functions:
# app.py
from urllib.parse import urlsplit

from flask import Flask, render_template, flash, redirect, url_for, request
from flask_login import login_user, logout_user, current_user, login_required
from models import db, User, Post
from forms import LoginForm, RegistrationForm, PostForm

# `app` is assumed to be created and configured at this point (SECRET_KEY,
# db.init_app(app), Flask-Login's LoginManager) before the routes are registered.

@app.route('/')
def index():
    page = request.args.get('page', 1, type=int)
    posts = Post.query.filter_by(is_published=True).order_by(
        Post.created_at.desc()).paginate(
        page=page, per_page=5, error_out=False)
    return render_template('index.html', posts=posts)

@app.route('/login', methods=['GET', 'POST'])
def login():
    if current_user.is_authenticated:
        return redirect(url_for('index'))
    form = LoginForm()
    if form.validate_on_submit():
        user = User.query.filter_by(username=form.username.data).first()
        if user is None or not user.check_password(form.password.data):
            flash('Invalid username or password')
            return redirect(url_for('login'))
        login_user(user, remember=form.remember_me.data)
        next_page = request.args.get('next')
        # Only follow relative URLs; a non-empty netloc would redirect off-site
        if not next_page or urlsplit(next_page).netloc != '':
            next_page = url_for('index')
        return redirect(next_page)
    return render_template('login.html', title='Sign In', form=form)

@app.route('/logout')
def logout():
    # Referenced by url_for('logout') in the index template
    logout_user()
    return redirect(url_for('index'))

@app.route('/register', methods=['GET', 'POST'])
def register():
    if current_user.is_authenticated:
        return redirect(url_for('index'))
    form = RegistrationForm()
    if form.validate_on_submit():
        user = User(username=form.username.data, email=form.email.data)
        user.set_password(form.password.data)
        db.session.add(user)
        db.session.commit()
        flash('Congratulations, you are now a registered user!')
        return redirect(url_for('login'))
    return render_template('register.html', title='Register', form=form)

@app.route('/create_post', methods=['GET', 'POST'])
@login_required
def create_post():
    form = PostForm()
    if form.validate_on_submit():
        post = Post(
            title=form.title.data,
            content=form.content.data,
            author=current_user,
            is_published=True
        )
        db.session.add(post)
        db.session.commit()
        flash('Your post has been published!')
        return redirect(url_for('index'))
    return render_template('create_post.html', title='Create Post', form=form)

@app.route('/post/<int:id>')
def post(id):
    post = Post.query.get_or_404(id)
    return render_template('post.html', post=post)
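
The login view redirects to the `next` query parameter only when it has no host part, which guards against open redirects. A slightly stricter version of that check can be sketched with the standard library alone. The helper name `safe_next_url` is hypothetical, and the extra rules (rejecting any scheme, normalising backslashes) are assumptions about what a cautious application might want rather than something Flask requires.

```python
from urllib.parse import urlsplit


def safe_next_url(next_page, fallback):
    """Return next_page only if it is a relative, same-site path; otherwise fallback."""
    if not next_page:
        return fallback
    # Some browsers treat backslashes like slashes, so normalise before parsing
    parts = urlsplit(next_page.replace('\\', '/'))
    # Any scheme or host means the URL points away from this application
    if parts.scheme or parts.netloc:
        return fallback
    # Require an absolute path so the target does not depend on the current URL
    if not next_page.startswith('/'):
        return fallback
    return next_page
```

In the login view this would replace the inline check, for example `return redirect(safe_next_url(request.args.get('next'), url_for('index')))`.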
3. Template example:
{% extends "base.html" %}
{% block content %}
<div class="row">
  <div class="col-md-8">
    <h1>Latest Posts</h1>
    {% for post in posts.items %}
    <div class="card mb-4">
      <div class="card-body">
        <h5 class="card-title">
          <a href="{{ url_for('post', id=post.id) }}">{{ post.title }}</a>
        </h5>
        <p class="card-text">{{ post.content[:200] }}...</p>
        <small class="text-muted">
          By {{ post.author.username }} on {{ post.created_at.strftime('%Y-%m-%d') }}
        </small>
      </div>
    </div>
    {% endfor %}
    <nav aria-label="Page navigation">
      <ul class="pagination">
        {% if posts.has_prev %}
        <li class="page-item">
          <a class="page-link" href="{{ url_for('index', page=posts.prev_num) }}">Previous</a>
        </li>
        {% endif %}
        {% for page_num in posts.iter_pages() %}
          {% if page_num %}
            {% if page_num != posts.page %}
            <li class="page-item">
              <a class="page-link" href="{{ url_for('index', page=page_num) }}">{{ page_num }}</a>
            </li>
            {% else %}
            <li class="page-item active">
              <span class="page-link">{{ page_num }}</span>
            </li>
            {% endif %}
          {% endif %}
        {% endfor %}
        {% if posts.has_next %}
        <li class="page-item">
          <a class="page-link" href="{{ url_for('index', page=posts.next_num) }}">Next</a>
        </li>
        {% endif %}
      </ul>
    </nav>
  </div>
  <div class="col-md-4">
    <div class="card">
      <div class="card-header">
        <h5>Quick Actions</h5>
      </div>
      <div class="card-body">
        {% if current_user.is_authenticated %}
        <a href="{{ url_for('create_post') }}" class="btn btn-primary btn-sm">Write Post</a>
        <a href="{{ url_for('logout') }}" class="btn btn-outline-secondary btn-sm">Logout</a>
        {% else %}
        <a href="{{ url_for('login') }}" class="btn btn-primary btn-sm">Login</a>
        <a href="{{ url_for('register') }}" class="btn btn-outline-primary btn-sm">Register</a>
        {% endif %}
      </div>
    </div>
  </div>
</div>
{% endblock %}
flask-blog/
├── app.py # Main application file
├── models.py # Data models
├── forms.py # Form definitions
├── config.py # Configuration
├── requirements.txt # Dependency list
├── static/ # Static files
│ ├── css/
│ ├── js/
│ └── images/
└── templates/ # Templates
├── base.html
├── index.html
├── post.html
└── admin/
⚙️ Core feature implementation
# models.py
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime

db = SQLAlchemy()

class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    # Werkzeug's default scrypt hashes are longer than 128 characters
    password_hash = db.Column(db.String(256))
    # backref adds an `author` attribute to every Post
    posts = db.relationship('Post', backref='author', lazy=True)

class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=False)
    content = db.Column(db.Text, nullable=False)
    date_posted = db.Column(db.DateTime, default=datetime.utcnow)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)

class Category(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), nullable=False)
    description = db.Column(db.Text)
Extension | Purpose | Install command |
---|---|---|
Flask-SQLAlchemy | Database ORM | pip install Flask-SQLAlchemy |
Flask-Login | User session management | pip install Flask-Login |
Flask-WTF | Form handling and validation | pip install Flask-WTF |
Flask-Mail | Sending email | pip install Flask-Mail |
Flask-Admin | Admin interface | pip install Flask-Admin |
Flask-RESTful | REST API development | pip install Flask-RESTful |
Recommended Flask project structure:
flask-app/
├── app/
│ ├── __init__.py # Application factory
│ ├── models.py # Data models
│ ├── forms.py # Form definitions
│ ├── views/ # View blueprints
│ │ ├── __init__.py
│ │ ├── auth.py # Authentication
│ │ ├── main.py # Main features
│ │ └── api.py # API endpoints
│ ├── templates/ # Templates
│ └── static/ # Static files
├── migrations/ # Database migrations
├── tests/ # Tests
├── config.py # Configuration
├── requirements.txt # Dependencies
└── run.py # Application launcher
Separating configuration by environment:
# config.py
import os

class Config:
    SECRET_KEY = os.environ.get('SECRET_KEY') or 'hard-to-guess-string'
    SQLALCHEMY_TRACK_MODIFICATIONS = False

    @staticmethod
    def init_app(app):
        pass

class DevelopmentConfig(Config):
    DEBUG = True
    SQLALCHEMY_DATABASE_URI = os.environ.get('DEV_DATABASE_URL') or \
        'sqlite:///' + os.path.join(os.path.dirname(__file__), 'data-dev.sqlite')

class ProductionConfig(Config):
    SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
        'sqlite:///' + os.path.join(os.path.dirname(__file__), 'data.sqlite')

config = {
    'development': DevelopmentConfig,
    'production': ProductionConfig,
    'default': DevelopmentConfig
}
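
The `config` dictionary needs something to choose a key at startup. One common approach, sketched below, reads the key from an environment variable and falls back to `'default'` for unknown values. The variable name `FLASK_CONFIG` and the helper `select_config_name` are assumptions made for illustration; neither is defined by Flask.

```python
import os

# Keys of the `config` dict defined in config.py
VALID_CONFIG_NAMES = {'development', 'production', 'default'}


def select_config_name(environ=None, default='default'):
    """Pick a key of the `config` dict from the (hypothetical) FLASK_CONFIG variable."""
    environ = os.environ if environ is None else environ
    name = environ.get('FLASK_CONFIG', default).strip().lower()
    # Unknown names fall back to the default rather than raising a KeyError later
    return name if name in VALID_CONFIG_NAMES else default
```

An application factory could then call `app.config.from_object(config[select_config_name()])`, so the same code runs with development or production settings depending on the environment.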
1. CSRF protection:
from flask_wtf.csrf import CSRFProtect
csrf = CSRFProtect(app)
2. Password security:
from werkzeug.security import generate_password_hash, check_password_hash
# Hash the password before storing it
password_hash = generate_password_hash('user_password')
# Verify a login attempt against the stored hash
is_valid = check_password_hash(password_hash, 'user_password')
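
To show what a salted password hash involves, here is a standard-library sketch using PBKDF2. It is a stand-in for illustration, not Werkzeug's implementation: the function names, the storage format, and the iteration count are all assumptions, and in a Flask application the Werkzeug helpers above remain the simpler choice.

```python
import hashlib
import hmac
import os

ITERATIONS = 100_000  # illustrative work factor; tune for your hardware


def hash_password(password, salt=None):
    """Return 'iterations$salt_hex$digest_hex' for the given password."""
    # A fresh random salt makes equal passwords produce different hashes
    salt = os.urandom(16) if salt is None else salt
    digest = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, ITERATIONS)
    return f'{ITERATIONS}${salt.hex()}${digest.hex()}'


def verify_password(stored, candidate):
    """Recompute the digest with the stored salt and compare in constant time."""
    iterations, salt_hex, digest_hex = stored.split('$')
    digest = hashlib.pbkdf2_hmac('sha256', candidate.encode('utf-8'),
                                 bytes.fromhex(salt_hex), int(iterations))
    return hmac.compare_digest(digest.hex(), digest_hex)
```

Because the salt and iteration count are stored alongside the digest, verification needs nothing but the stored string and the candidate password.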
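3. Input validation:
from flask import request
from markupsafe import escape  # flask.escape was removed in Flask 3.0
# Escape user input
user_input = escape(request.form['user_input'])
# Validate with WTForms
from flask_wtf import FlaskForm
from wtforms import StringField, TextAreaField
from wtforms.validators import DataRequired, Length, Email

class SafeForm(FlaskForm):
    email = StringField('Email', validators=[DataRequired(), Email()])
    content = TextAreaField('Content', validators=[Length(max=1000)])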
1. Database query optimization:
# Use eager loading to avoid N+1 queries
users = User.query.options(db.joinedload(User.posts)).all()
# Paginate queries
posts = Post.query.paginate(page=1, per_page=20, error_out=False)
# Add indexes to frequently queried columns
class User(db.Model):
    # ... other columns omitted ...
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
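2. Caching strategy:
from flask_caching import Cache
# 'SimpleCache' is the in-memory backend; the older alias 'simple' is deprecated
cache = Cache(app, config={'CACHE_TYPE': 'SimpleCache'})

@app.route('/expensive-operation')
@cache.cached(timeout=300)  # cache for 5 minutes
def expensive_operation():
    # Time-consuming work goes here
    return render_template('result.html')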
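
The idea behind a timed cache such as `cache.cached(timeout=300)` can be sketched in a few lines of plain Python. This is a simplified illustration, not Flask-Caching's implementation: the decorator name `ttl_cached` and its injectable `clock` parameter are hypothetical, and it handles only functions that take no arguments.

```python
import time
from functools import wraps


def ttl_cached(timeout, clock=time.monotonic):
    """Cache a zero-argument function's result for `timeout` seconds."""
    def decorator(func):
        state = {'expires': None, 'value': None}

        @wraps(func)
        def wrapper():
            now = clock()
            # Recompute on the first call or once the cached value has expired
            if state['expires'] is None or now >= state['expires']:
                state['value'] = func()
                state['expires'] = now + timeout
            return state['value']
        return wrapper
    return decorator
```

Passing the clock in as a parameter keeps the expiry logic testable without sleeping; a real cache also has to key on request arguments and cope with concurrent access.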
Unit test example:
# tests/test_models.py
import unittest
from app import create_app, db
from app.models import User

class UserModelTestCase(unittest.TestCase):
    def setUp(self):
        # Assumes the config dict also defines a 'testing' entry
        self.app = create_app('testing')
        self.app_context = self.app.app_context()
        self.app_context.push()
        db.create_all()

    def tearDown(self):
        db.session.remove()
        db.drop_all()
        self.app_context.pop()

    def test_password_hashing(self):
        u = User(username='test')
        u.set_password('password')
        self.assertFalse(u.check_password('wrong'))
        self.assertTrue(u.check_password('password'))

if __name__ == '__main__':
    unittest.main()
import logging
import os
from logging.handlers import RotatingFileHandler

# Write to a rotating log file only outside debug mode
if not app.debug:
    if not os.path.exists('logs'):
        os.mkdir('logs')
    # Keep up to 10 backup files of about 10 KB each
    file_handler = RotatingFileHandler('logs/app.log', maxBytes=10240, backupCount=10)
    file_handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
    ))
    file_handler.setLevel(logging.INFO)
    app.logger.addHandler(file_handler)
    app.logger.setLevel(logging.INFO)
    app.logger.info('Flask application startup')
Extension | Description | Install command |
---|---|---|
Flask-SQLAlchemy | Database ORM | pip install Flask-SQLAlchemy |
Flask-Login | User session management | pip install Flask-Login |
Flask-WTF | Form handling and CSRF protection | pip install Flask-WTF |
Flask-Mail | Sending email | pip install Flask-Mail |
Flask-Migrate | Database migrations | pip install Flask-Migrate |
Flask-Caching | Caching support | pip install Flask-Caching |
Flask-CORS | Cross-origin resource sharing | pip install Flask-CORS |
Flask-JWT-Extended | JWT authentication | pip install Flask-JWT-Extended |
1. Production deployment:
# Using Gunicorn
pip install gunicorn
gunicorn -w 4 -b 0.0.0.0:8000 app:app
# Using uWSGI
pip install uwsgi
uwsgi --http :8000 --module app:app
2. Docker deployment:
FROM python:3.9-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
EXPOSE 5000
CMD ["gunicorn", "-b", "0.0.0.0:5000", "app:app"]
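Congratulations on completing your Flask learning journey!
From basic syntax to a hands-on project, you have now covered the core skills of Flask web development.
Keep exploring and keep improving!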
large-flask-app/
├── app/
│ ├── __init__.py
│ ├── models/
│ ├── views/
│ ├── templates/
│ └── static/
├── migrations/
├── tests/
├── config.py
├── requirements.txt
└── run.py
# 1. Enable debug mode (development only)
app.run(debug=True)
# 2. Use logging
import logging
logging.basicConfig(level=logging.DEBUG)
app.logger.debug('This is a debug message')
# 3. Use Flask-DebugToolbar
from flask_debugtoolbar import DebugToolbarExtension
toolbar = DebugToolbarExtension(app)
# 4. Custom error pages
@app.errorhandler(404)
def not_found(error):
    return render_template('404.html'), 404

@app.errorhandler(500)
def internal_error(error):
    return render_template('500.html'), 500
# test_app.py
import unittest
from app import app, db
from app.models import User, Post

class FlaskTestCase(unittest.TestCase):
    def setUp(self):
        app.config['TESTING'] = True
        # Let test form posts through without a CSRF token
        app.config['WTF_CSRF_ENABLED'] = False
        app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
        self.app = app.test_client()
        with app.app_context():
            db.create_all()

    def tearDown(self):
        with app.app_context():
            db.drop_all()

    def test_home_page(self):
        response = self.app.get('/')
        self.assertEqual(response.status_code, 200)
        self.assertIn(b'Welcome', response.data)

    def test_user_registration(self):
        response = self.app.post('/register', data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password': 'testpass',
            'password2': 'testpass'  # RegistrationForm requires the repeated password
        })
        self.assertEqual(response.status_code, 302)  # Redirect after success

if __name__ == '__main__':
    unittest.main()
# Dockerfile
FROM python:3.9-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
EXPOSE 5000
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "app:app"]
# docker-compose.yml
version: '3.8'
services:
  web:
    build: .
    ports:
      - "5000:5000"
    environment:
      - FLASK_ENV=production
      - DATABASE_URL=postgresql://user:pass@db:5432/flaskapp
    depends_on:
      - db
      - redis
  db:
    image: postgres:13
    environment:
      POSTGRES_DB: flaskapp
      POSTGRES_USER: user
      POSTGRES_PASSWORD: pass
    volumes:
      - postgres_data:/var/lib/postgresql/data
  redis:
    image: redis:6-alpine
volumes:
  postgres_data:
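# 1. Install the Heroku CLI
# 2. Log in to Heroku
heroku login
# 3. Create the app
heroku create your-app-name
# 4. Add a database (the free hobby-dev plan has been retired; substitute a currently available plan name)
heroku addons:create heroku-postgresql:hobby-dev
# 5. Set environment variables
heroku config:set FLASK_ENV=production
# 6. Deploy
git push heroku main
# Procfile
web: gunicorn app:app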
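Flask: lightweight and easy to learn; suited to small and medium-sized projects
FastAPI: modern, with async support; strong for API development
Django: full-featured and suited to large projects, with a steeper learning curve
Start your Flask learning journey!
Remember: the best way to learn is hands-on practice!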