169 lines
5.9 KiB
Python
169 lines
5.9 KiB
Python
"""认证模块 — Flask-Login 集成"""
|
||
|
||
from flask import Blueprint, render_template, request, redirect, url_for, flash
|
||
from flask_login import LoginManager, login_user, logout_user, login_required, current_user
|
||
from werkzeug.security import check_password_hash
|
||
from app.models import get_user_by_username, insert_log
|
||
|
||
# Flask-Login manager. Requests that hit a @login_required view without a
# logged-in user are sent to the "auth.login" endpoint, and the message
# below is flashed.
login_manager = LoginManager()
login_manager.login_message = "请先登录"
login_manager.login_view = "auth.login"

# Blueprint that owns the login / logout / change-password routes below.
auth_bp = Blueprint("auth", __name__)
|
||
|
||
|
||
class User:
    """Flask-Login user object built from a ``tb_user`` row dict.

    Flask-Login expects the user object to provide ``is_authenticated``,
    ``is_active``, ``is_anonymous`` and ``get_id()``; all four are defined
    here.
    """

    def __init__(self, user_dict):
        """Copy the fields used by this module out of *user_dict*.

        Args:
            user_dict: mapping with at least ``id``, ``username`` and
                ``role`` keys. ``is_active`` is optional; when it is absent
                the account is treated as active.
        """
        self.id = user_dict["id"]
        self.username = user_dict["username"]
        self.role = user_dict["role"]
        self.is_active = bool(user_dict.get("is_active", 1))

    @property
    def is_authenticated(self):
        """Always True.

        Anonymous visitors are represented by Flask-Login's own anonymous
        user object, not by this class.
        """
        return True

    @property
    def is_anonymous(self):
        """Always False; part of the user interface Flask-Login requires."""
        return False

    def get_id(self):
        """Return the id as a string, the form Flask-Login stores in the session."""
        return str(self.id)

    def __repr__(self):
        return f"{type(self).__name__}(id={self.id!r}, username={self.username!r}, role={self.role!r})"
|
||
|
||
|
||
@login_manager.user_loader
def load_user(user_id):
    """Reload the session's user from ``tb_user``; called by Flask-Login per request.

    Args:
        user_id: the id stored in the session (the string produced by
            ``User.get_id``).

    Returns:
        A ``User`` for the matching row, or ``None`` when *user_id* is not a
        valid integer or no row matches. Returning ``None`` (rather than
        raising) lets Flask-Login treat the request as anonymous.
    """
    from app.models import get_conn

    # Validate before taking a DB connection; a bad id must not cause a 500.
    try:
        uid = int(user_id)
    except (TypeError, ValueError):
        return None

    conn = get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM tb_user WHERE id=%s", (uid,))
            row = cur.fetchone()
    finally:
        conn.close()

    # NOTE(review): rows with is_active = 0 are still returned, so a disabled
    # account keeps any existing session; only login() rejects disabled users.
    return User(row) if row else None
|
||
|
||
|
||
def init_auth(app):
    """Attach Flask-Login to *app* and install the analyst endpoint whitelist.

    Authenticated users whose role is ``"analyst"`` may only reach the
    endpoints in ``ANALYST_ALLOWED`` plus static-file endpoints. Any other
    request flashes a notice and redirects them to the test-data page.
    """
    login_manager.init_app(app)

    # analyst role: global endpoint whitelist
    ANALYST_ALLOWED = frozenset({
        "auth.login", "auth.logout", "auth.change_password",
        "test_data.test_data_page",
        "test_data.api_test_data",
        "test_data.api_chart_data",
        "test_data.api_export",
        "test_data.api_delete",  # the view performs its own inline role check
    })

    def _is_static_endpoint(endpoint):
        # The app-level static endpoint is "static"; blueprint static folders
        # are "<blueprint>.static". A bare prefix test would also let through
        # any endpoint whose name merely starts with "static" (e.g. "statistics.*").
        return endpoint == "static" or endpoint.endswith(".static")

    @app.before_request
    def _restrict_analyst():
        if not (current_user.is_authenticated and current_user.role == "analyst"):
            return None
        ep = request.endpoint or ""
        if ep in ANALYST_ALLOWED or _is_static_endpoint(ep):
            return None
        flash("当前角色为 analyst,仅可访问测试数据")
        return redirect(url_for("test_data.test_data_page"))
|
||
|
||
|
||
# ─── 装饰器 ────────────────────────────────────────────────────────
|
||
|
||
def admin_required(f):
    """Decorator: allow the view only for users whose role is exactly "admin".

    ``login_required`` runs first, so anonymous requests get Flask-Login's
    unauthorized handling. An authenticated user with any other role
    (including "manager") receives ``("权限不足", 403)``.
    """
    from functools import wraps

    def _admin_only(*args, **kwargs):
        if current_user.role == "admin":
            return f(*args, **kwargs)
        return "权限不足", 403

    return wraps(f)(login_required(_admin_only))
|
||
|
||
|
||
def privileged_required(f):
    """Decorator: allow the view only for users with role "admin" or "manager".

    ``login_required`` runs first, so anonymous requests get Flask-Login's
    unauthorized handling. Any other authenticated role receives
    ``("权限不足", 403)``.
    """
    from functools import wraps

    def _check_role(*args, **kwargs):
        role = current_user.role
        is_privileged = role == "admin" or role == "manager"
        if not is_privileged:
            return "权限不足", 403
        return f(*args, **kwargs)

    return wraps(f)(login_required(_check_role))
|
||
|
||
|
||
# ─── 登录 / 登出 ────────────────────────────────────────────────────
|
||
|
||
def _safe_next_url(target):
    """Return *target* if it is a same-site absolute path, otherwise ``None``.

    The ``next`` query parameter is user-controlled. Redirecting to it
    unchecked would turn the login page into an open redirect to any host,
    so only paths such as ``/devices?page=2`` are accepted.
    """
    if not target or not target.startswith("/"):
        return None
    # "//host" and "///host" are resolved by browsers as another host,
    # browsers treat "\" like "/", and they silently drop tab/newline
    # characters. Each of these can smuggle in a foreign host.
    if target.startswith("//") or "\\" in target:
        return None
    if any(ord(ch) < 0x20 or ch == "\x7f" for ch in target):
        return None
    return target


@auth_bp.route("/login", methods=["GET", "POST"])
def login():
    """Render the login form (GET) or authenticate submitted credentials (POST).

    On success the user is logged in through Flask-Login, an "ok" audit entry
    is written, and the browser is redirected to the validated ``next`` path
    or to ``devices.index``. On failure an "error" audit entry is written, a
    generic message is flashed, and the form is rendered again.
    """
    if request.method == "POST":
        username = request.form.get("username", "").strip()
        password = request.form.get("password", "")
        user_dict = get_user_by_username(username)
        ip = request.remote_addr or ""

        # Unknown user, disabled account and wrong password all share one
        # message, so the response does not reveal which check failed.
        if (
            user_dict
            and user_dict.get("is_active")
            and check_password_hash(user_dict["password_hash"], password)
        ):
            user = User(user_dict)
            login_user(user)
            insert_log(user.id, user.username, "login", ip=ip, result="ok")
            # Only follow "next" when it stays on this site.
            next_page = _safe_next_url(request.args.get("next"))
            return redirect(next_page or url_for("devices.index"))

        # No account was authenticated, so the audit entry uses user id 0.
        insert_log(0, username, "login", detail="密码错误或账号禁用", ip=ip, result="error")
        flash("用户名或密码错误")

    return render_template("login.html")
|
||
|
||
|
||
@auth_bp.route("/logout")
@login_required
def logout():
    """Write a logout audit entry, end the Flask-Login session, and go to the login page."""
    client_ip = request.remote_addr or ""
    insert_log(current_user.id, current_user.username, "logout", ip=client_ip, result="ok")
    logout_user()
    login_page = url_for("auth.login")
    return redirect(login_page)
|
||
|
||
|
||
@auth_bp.route("/change-password", methods=["GET", "POST"])
@login_required
def change_password():
    """Let the logged-in user change their own password.

    GET renders the form. POST does the following:
      * checks that all fields are filled, the new password has at least
        six characters, and the confirmation matches;
      * re-verifies the current password;
      * stores a new hash in ``tb_user`` and writes an audit entry.

    Every validation failure flashes a message and re-renders the form.

    The new password is stored exactly as typed (it is not stripped), because
    ``login()`` checks the submitted password without stripping it. Stripping
    here would store a hash the user could never log in with.
    """
    if request.method != "POST":
        return render_template("change_password.html")

    old_password = request.form.get("old_password", "")
    new_password = request.form.get("new_password", "")
    confirm_password = request.form.get("confirm_password", "")

    # Surrounding whitespace does not count toward "non-empty" or the minimum
    # length, matching the earlier behaviour of judging the stripped value.
    if not old_password or not new_password.strip() or not confirm_password:
        flash("所有字段都不能为空")
        return render_template("change_password.html")

    if len(new_password.strip()) < 6:
        flash("新密码至少6位")
        return render_template("change_password.html")

    # Compare the raw values; previously only one side was stripped, so any
    # password with leading/trailing spaces could never be confirmed.
    if new_password != confirm_password:
        flash("两次输入的新密码不一致")
        return render_template("change_password.html")

    from app.models import get_conn
    from werkzeug.security import generate_password_hash

    # Verify the current password
    user_dict = get_user_by_username(current_user.username)
    if not user_dict or not check_password_hash(user_dict["password_hash"], old_password):
        flash("原密码错误")
        return render_template("change_password.html")

    # Hash before taking a DB connection so the connection is held briefly.
    new_hash = generate_password_hash(new_password)
    conn = get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE tb_user SET password_hash=%s WHERE id=%s",
                (new_hash, current_user.id),
            )
        conn.commit()
    finally:
        conn.close()

    insert_log(current_user.id, current_user.username, "update",
               target="self", detail="修改个人密码",
               result="ok", ip=request.remote_addr or "")
    flash("密码修改成功")
    return redirect(url_for("devices.index"))
|