mirror of
https://github.com/zebrajr/ollama-webui.git
synced 2026-01-15 12:15:13 +00:00
enh/refac: kb pagination
This commit is contained in:
@@ -104,6 +104,11 @@ class FileUpdateForm(BaseModel):
|
||||
meta: Optional[dict] = None
|
||||
|
||||
|
||||
class FileListResponse(BaseModel):
    """Response model holding a list of files together with a total count."""

    # Files carried by this response.
    items: list[FileModel]
    # Count reported alongside the items; presumably the number of matching
    # files before pagination — confirm against the producing query.
    total: int
|
||||
|
||||
|
||||
class FilesTable:
|
||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
import uuid
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.models.files import (
|
||||
@@ -30,6 +31,8 @@ from sqlalchemy import (
|
||||
)
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.utils.db.access_control import has_permission
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -145,6 +148,11 @@ class FileUserResponse(FileModelResponse):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class KnowledgeListResponse(BaseModel):
    """Knowledge bases (each with its owner attached) plus a total count."""

    # Knowledge bases carried by this response.
    items: list[KnowledgeUserModel]
    # Total count reported alongside the items.
    total: int
|
||||
|
||||
|
||||
class KnowledgeFileListResponse(BaseModel):
    """Files (each with an optional user attached) plus a total count."""

    # Files carried by this response.
    items: list[FileUserResponse]
    # Total count reported alongside the items.
    total: int
|
||||
@@ -177,12 +185,13 @@ class KnowledgeTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
|
||||
def get_knowledge_bases(
|
||||
self, skip: int = 0, limit: int = 30
|
||||
) -> list[KnowledgeUserModel]:
|
||||
with get_db() as db:
|
||||
all_knowledge = (
|
||||
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||
)
|
||||
|
||||
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
|
||||
@@ -201,6 +210,126 @@ class KnowledgeTable:
|
||||
)
|
||||
return knowledge_bases
|
||||
|
||||
def search_knowledge_bases(
    self, user_id: str, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeListResponse:
    """Search knowledge bases and return one page of results with owners.

    Args:
        user_id: Id compared against ``Knowledge.user_id`` when
            ``view_option`` is set.
        filter: Search options. Keys read here are ``query``
            (case-insensitive substring match on name or description) and
            ``view_option`` (``"created"`` keeps rows owned by ``user_id``,
            ``"shared"`` keeps rows owned by anyone else). A non-empty
            dict is also forwarded to ``has_permission`` for access
            filtering.
        skip: Number of rows to skip; a falsy value applies no offset.
        limit: Maximum rows to return; a falsy value applies no limit.

    Returns:
        A ``KnowledgeListResponse`` whose ``total`` is the number of
        matching rows before ``skip``/``limit``. If anything raises, the
        error is logged and an empty response with ``total=0`` is
        returned.
    """
    try:
        with get_db() as db:
            query = db.query(Knowledge, User).outerjoin(
                User, User.id == Knowledge.user_id
            )

            if filter:
                query_key = filter.get("query")
                if query_key:
                    pattern = f"%{query_key}%"
                    query = query.filter(
                        or_(
                            Knowledge.name.ilike(pattern),
                            Knowledge.description.ilike(pattern),
                        )
                    )

                view_option = filter.get("view_option")
                if view_option == "created":
                    query = query.filter(Knowledge.user_id == user_id)
                elif view_option == "shared":
                    query = query.filter(Knowledge.user_id != user_id)

                query = has_permission(db, Knowledge, query, filter)

            # Count before ordering: the total does not depend on row order,
            # so the count subquery need not carry an ORDER BY.
            total = query.count()

            query = query.order_by(Knowledge.updated_at.desc())
            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            knowledge_bases = [
                KnowledgeUserModel.model_validate(
                    {
                        **KnowledgeModel.model_validate(knowledge_base).model_dump(),
                        "user": (
                            UserModel.model_validate(user).model_dump()
                            if user
                            else None
                        ),
                    }
                )
                for knowledge_base, user in query.all()
            ]

            return KnowledgeListResponse(items=knowledge_bases, total=total)
    except Exception:
        # Best-effort search: keep returning an empty page on failure, but
        # record the traceback through the module logger instead of stdout.
        log.exception("search_knowledge_bases error")
        return KnowledgeListResponse(items=[], total=0)
|
||||
|
||||
def search_knowledge_files(
    self, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeFileListResponse:
    """Search files across the knowledge bases permitted by ``filter``.

    Access control is applied in the database by passing the joined query
    to ``has_permission`` (default ``"read"`` permission) rather than by
    loading knowledge bases first.

    Args:
        filter: Forwarded to ``has_permission``; its optional ``query`` key
            is additionally used for a case-insensitive substring match on
            the file name.
        skip: Number of rows to skip; a falsy value applies no offset.
        limit: Maximum rows to return; a falsy value applies no limit.

    Returns:
        A ``KnowledgeFileListResponse`` whose ``total`` is the number of
        matching rows before ``skip``/``limit``. If anything raises, the
        error is logged and an empty response with ``total=0`` is
        returned.
    """
    try:
        with get_db() as db:
            # File -> KnowledgeFile -> Knowledge, with the KnowledgeFile's
            # user attached when one exists (outer join).
            query = (
                db.query(File, User)
                .join(KnowledgeFile, File.id == KnowledgeFile.file_id)
                .join(Knowledge, KnowledgeFile.knowledge_id == Knowledge.id)
                .outerjoin(User, User.id == KnowledgeFile.user_id)
            )

            # Access control is evaluated against the Knowledge rows.
            query = has_permission(db, Knowledge, query, filter)

            if filter:
                q = filter.get("query")
                if q:
                    query = query.filter(File.filename.ilike(f"%{q}%"))

            # Count before ordering and pagination: the total does not
            # depend on row order.
            total = query.count()

            query = query.order_by(File.updated_at.desc())
            if skip:
                query = query.offset(skip)
            if limit:
                query = query.limit(limit)

            items = [
                FileUserResponse(
                    **FileModel.model_validate(file).model_dump(),
                    user=(
                        UserResponse(**UserModel.model_validate(user).model_dump())
                        if user
                        else None
                    ),
                )
                for file, user in query.all()
            ]

            return KnowledgeFileListResponse(items=items, total=total)
    except Exception:
        # Best-effort search: keep returning an empty page on failure, but
        # record the traceback through the module logger instead of stdout.
        log.exception("search_knowledge_files error")
        return KnowledgeFileListResponse(items=[], total=0)
|
||||
|
||||
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
||||
knowledge = self.get_knowledge_by_id(id)
|
||||
if not knowledge:
|
||||
|
||||
@@ -39,7 +39,6 @@ from open_webui.models.knowledge import Knowledges
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
|
||||
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
|
||||
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||
from open_webui.routers.audio import transcribe
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
import logging
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.knowledge import (
|
||||
KnowledgeFileListResponse,
|
||||
Knowledges,
|
||||
@@ -40,53 +41,115 @@ router = APIRouter()
|
||||
# getKnowledgeBases
|
||||
############################
|
||||
|
||||
PAGE_ITEM_COUNT = 30
|
||||
|
||||
|
||||
class KnowledgeAccessResponse(KnowledgeUserResponse):
    """Knowledge-base response extended with a write-access flag."""

    # Whether the requesting user may write; False when not supplied.
    write_access: Optional[bool] = False
|
||||
|
||||
|
||||
@router.get("/", response_model=list[KnowledgeAccessResponse])
async def get_knowledge(user=Depends(get_verified_user)):
    """List knowledge bases readable by the caller, flagging write access.

    Admins get every knowledge base while ``BYPASS_ADMIN_ACCESS_CONTROL``
    is truthy; everyone else gets the ``"read"`` lookup for their user id.
    Each entry carries its file metadata and a ``write_access`` flag.
    """
    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
        visible = Knowledges.get_knowledge_bases()
    else:
        visible = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")

    response = []
    for kb in visible:
        payload = kb.model_dump()
        file_metadatas = Knowledges.get_file_metadatas_by_id(kb.id)
        # Owners can always write; otherwise defer to the access-control check.
        can_write = user.id == kb.user_id or has_access(
            user.id, "write", kb.access_control
        )
        response.append(
            KnowledgeAccessResponse(
                **payload,
                files=file_metadatas,
                write_access=can_write,
            )
        )
    return response
|
||||
class KnowledgeAccessListResponse(BaseModel):
    """Knowledge bases with write-access flags, plus a total count."""

    # Knowledge bases carried by this response.
    items: list[KnowledgeAccessResponse]
    # Total count reported alongside the items.
    total: int
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[KnowledgeAccessResponse])
|
||||
async def get_knowledge_list(user=Depends(get_verified_user)):
|
||||
# Return knowledge bases with write access
|
||||
knowledge_bases = []
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases()
|
||||
else:
|
||||
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
|
||||
@router.get("/", response_model=KnowledgeAccessListResponse)
async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)):
    """Return one page of knowledge bases visible to the caller.

    Args:
        page: 1-based page number; ``None`` and values below 1 are treated
            as page 1. Each page holds ``PAGE_ITEM_COUNT`` items.
        user: The verified requesting user.

    Returns:
        A ``KnowledgeAccessListResponse`` with the page's items (each with
        a ``write_access`` flag) and the total reported by the search.
    """
    # `page or 1` guards the Optional[int] annotation: max(None, 1) raises.
    page = max(page or 1, 1)
    limit = PAGE_ITEM_COUNT
    skip = (page - 1) * limit

    # An empty filter applies no access restriction; that is only allowed
    # for admins while BYPASS_ADMIN_ACCESS_CONTROL is truthy.
    search_filter = {}
    if user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
        groups = Groups.get_groups_by_member_id(user.id)
        if groups:
            search_filter["group_ids"] = [group.id for group in groups]

        search_filter["user_id"] = user.id

    result = Knowledges.search_knowledge_bases(
        user.id, filter=search_filter, skip=skip, limit=limit
    )

    return KnowledgeAccessListResponse(
        items=[
            KnowledgeAccessResponse(
                **knowledge_base.model_dump(),
                # Owners can always write; otherwise defer to has_access.
                write_access=(
                    user.id == knowledge_base.user_id
                    or has_access(user.id, "write", knowledge_base.access_control)
                ),
            )
            for knowledge_base in result.items
        ],
        total=result.total,
    )
|
||||
|
||||
|
||||
@router.get("/search", response_model=KnowledgeAccessListResponse)
async def search_knowledge_bases(
    query: Optional[str] = None,
    view_option: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_verified_user),
):
    """Search knowledge bases visible to the caller, one page at a time.

    Args:
        query: Optional search text, forwarded as the ``query`` filter key.
        view_option: Optional view selector, forwarded as the
            ``view_option`` filter key.
        page: 1-based page number; ``None`` and values below 1 are treated
            as page 1. Each page holds ``PAGE_ITEM_COUNT`` items.
        user: The verified requesting user.

    Returns:
        A ``KnowledgeAccessListResponse`` with the page's items (each with
        a ``write_access`` flag) and the total reported by the search.
    """
    # `page or 1` guards the Optional[int] annotation: max(None, 1) raises.
    page = max(page or 1, 1)
    limit = PAGE_ITEM_COUNT
    skip = (page - 1) * limit

    search_filter = {}
    if query:
        search_filter["query"] = query
    if view_option:
        search_filter["view_option"] = view_option

    # Access keys are omitted (no restriction) only for admins while
    # BYPASS_ADMIN_ACCESS_CONTROL is truthy.
    if user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
        groups = Groups.get_groups_by_member_id(user.id)
        if groups:
            search_filter["group_ids"] = [group.id for group in groups]

        search_filter["user_id"] = user.id

    result = Knowledges.search_knowledge_bases(
        user.id, filter=search_filter, skip=skip, limit=limit
    )

    return KnowledgeAccessListResponse(
        items=[
            KnowledgeAccessResponse(
                **knowledge_base.model_dump(),
                # Owners can always write; otherwise defer to has_access.
                write_access=(
                    user.id == knowledge_base.user_id
                    or has_access(user.id, "write", knowledge_base.access_control)
                ),
            )
            for knowledge_base in result.items
        ],
        total=result.total,
    )
|
||||
|
||||
|
||||
@router.get("/search/files", response_model=KnowledgeFileListResponse)
async def search_knowledge_files(
    query: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_verified_user),
):
    """Search files in knowledge bases the caller can access, page by page.

    Args:
        query: Optional search text, forwarded as the ``query`` filter key.
        page: 1-based page number; ``None`` and values below 1 are treated
            as page 1. Each page holds ``PAGE_ITEM_COUNT`` items.
        user: The verified requesting user.

    Returns:
        The ``KnowledgeFileListResponse`` produced by
        ``Knowledges.search_knowledge_files``.
    """
    # `page or 1` guards the Optional[int] annotation: max(None, 1) raises.
    page = max(page or 1, 1)
    limit = PAGE_ITEM_COUNT
    skip = (page - 1) * limit

    search_filter = {}
    if query:
        search_filter["query"] = query

    # NOTE(review): unlike the knowledge-base routes in this module, the
    # access keys are always set here, so admins get no bypass — confirm
    # this is intentional.
    groups = Groups.get_groups_by_member_id(user.id)
    if groups:
        search_filter["group_ids"] = [group.id for group in groups]

    search_filter["user_id"] = user.id

    return Knowledges.search_knowledge_files(
        filter=search_filter, skip=skip, limit=limit
    )
|
||||
|
||||
|
||||
############################
|
||||
@@ -198,7 +261,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
|
||||
|
||||
|
||||
class KnowledgeFilesResponse(KnowledgeResponse):
    """Knowledge-base response with optional file metadata and a write flag."""

    # File metadata is optional; None when the caller does not supply it.
    files: Optional[list[FileMetadataResponse]] = None
    # Whether the requesting user may write; False when not supplied.
    write_access: Optional[bool] = False
|
||||
|
||||
|
||||
@@ -215,7 +278,6 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
|
||||
write_access=(
|
||||
user.id == knowledge.user_id
|
||||
or has_access(user.id, "write", knowledge.access_control)
|
||||
|
||||
130
backend/open_webui/utils/db/access_control.py
Normal file
130
backend/open_webui/utils/db/access_control.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
|
||||
from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func
|
||||
|
||||
|
||||
def has_permission(db, DocumentModel, query, filter: dict, permission: str = "read"):
|
||||
group_ids = filter.get("group_ids", [])
|
||||
user_id = filter.get("user_id")
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
conditions = []
|
||||
|
||||
# Handle read_only permission separately
|
||||
if permission == "read_only":
|
||||
# For read_only, we want items where:
|
||||
# 1. User has explicit read permission (via groups or user-level)
|
||||
# 2. BUT does NOT have write permission
|
||||
# 3. Public items are NOT considered read_only
|
||||
|
||||
read_conditions = []
|
||||
|
||||
# Group-level read permission
|
||||
if group_ids:
|
||||
group_read_conditions = []
|
||||
for gid in group_ids:
|
||||
if dialect_name == "sqlite":
|
||||
group_read_conditions.append(
|
||||
DocumentModel.access_control["read"]["group_ids"].contains(
|
||||
[gid]
|
||||
)
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
group_read_conditions.append(
|
||||
cast(
|
||||
DocumentModel.access_control["read"]["group_ids"],
|
||||
JSONB,
|
||||
).contains([gid])
|
||||
)
|
||||
|
||||
if group_read_conditions:
|
||||
read_conditions.append(or_(*group_read_conditions))
|
||||
|
||||
# Combine read conditions
|
||||
if read_conditions:
|
||||
has_read = or_(*read_conditions)
|
||||
else:
|
||||
# If no read conditions, return empty result
|
||||
return query.filter(False)
|
||||
|
||||
# Now exclude items where user has write permission
|
||||
write_exclusions = []
|
||||
|
||||
# Exclude items owned by user (they have implicit write)
|
||||
if user_id:
|
||||
write_exclusions.append(DocumentModel.user_id != user_id)
|
||||
|
||||
# Exclude items where user has explicit write permission via groups
|
||||
if group_ids:
|
||||
group_write_conditions = []
|
||||
for gid in group_ids:
|
||||
if dialect_name == "sqlite":
|
||||
group_write_conditions.append(
|
||||
DocumentModel.access_control["write"]["group_ids"].contains(
|
||||
[gid]
|
||||
)
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
group_write_conditions.append(
|
||||
cast(
|
||||
DocumentModel.access_control["write"]["group_ids"],
|
||||
JSONB,
|
||||
).contains([gid])
|
||||
)
|
||||
|
||||
if group_write_conditions:
|
||||
# User should NOT have write permission
|
||||
write_exclusions.append(~or_(*group_write_conditions))
|
||||
|
||||
# Exclude public items (items without access_control)
|
||||
write_exclusions.append(DocumentModel.access_control.isnot(None))
|
||||
write_exclusions.append(cast(DocumentModel.access_control, String) != "null")
|
||||
|
||||
# Combine: has read AND does not have write AND not public
|
||||
if write_exclusions:
|
||||
query = query.filter(and_(has_read, *write_exclusions))
|
||||
else:
|
||||
query = query.filter(has_read)
|
||||
|
||||
return query
|
||||
|
||||
# Original logic for other permissions (read, write, etc.)
|
||||
# Public access conditions
|
||||
if group_ids or user_id:
|
||||
conditions.extend(
|
||||
[
|
||||
DocumentModel.access_control.is_(None),
|
||||
cast(DocumentModel.access_control, String) == "null",
|
||||
]
|
||||
)
|
||||
|
||||
# User-level permission (owner has all permissions)
|
||||
if user_id:
|
||||
conditions.append(DocumentModel.user_id == user_id)
|
||||
|
||||
# Group-level permission
|
||||
if group_ids:
|
||||
group_conditions = []
|
||||
for gid in group_ids:
|
||||
if dialect_name == "sqlite":
|
||||
group_conditions.append(
|
||||
DocumentModel.access_control[permission]["group_ids"].contains(
|
||||
[gid]
|
||||
)
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
group_conditions.append(
|
||||
cast(
|
||||
DocumentModel.access_control[permission]["group_ids"],
|
||||
JSONB,
|
||||
).contains([gid])
|
||||
)
|
||||
conditions.append(or_(*group_conditions))
|
||||
|
||||
if conditions:
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query
|
||||
Reference in New Issue
Block a user