enh/refac: kb pagination

This commit is contained in:
Timothy Jaeryang Baek
2025-12-10 23:19:19 -05:00
parent 3ed1df2e53
commit ceae3d48e6
18 changed files with 1086 additions and 526 deletions

View File

@@ -104,6 +104,11 @@ class FileUpdateForm(BaseModel):
meta: Optional[dict] = None
class FileListResponse(BaseModel):
items: list[FileModel]
total: int
class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:

View File

@@ -5,6 +5,7 @@ from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
@@ -30,6 +31,8 @@ from sqlalchemy import (
)
from open_webui.utils.access_control import has_access
from open_webui.utils.db.access_control import has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -145,6 +148,11 @@ class FileUserResponse(FileModelResponse):
user: Optional[UserResponse] = None
class KnowledgeListResponse(BaseModel):
items: list[KnowledgeUserModel]
total: int
class KnowledgeFileListResponse(BaseModel):
items: list[FileUserResponse]
total: int
@@ -177,12 +185,13 @@ class KnowledgeTable:
except Exception:
return None
def get_knowledge_bases(self) -> list[KnowledgeUserModel]:
def get_knowledge_bases(
self, skip: int = 0, limit: int = 30
) -> list[KnowledgeUserModel]:
with get_db() as db:
all_knowledge = (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
)
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
users = Users.get_users_by_user_ids(user_ids) if user_ids else []
@@ -201,6 +210,126 @@ class KnowledgeTable:
)
return knowledge_bases
def search_knowledge_bases(
self, user_id: str, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeListResponse:
try:
with get_db() as db:
query = db.query(Knowledge, User).outerjoin(
User, User.id == Knowledge.user_id
)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Knowledge.name.ilike(f"%{query_key}%"),
Knowledge.description.ilike(f"%{query_key}%"),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Knowledge.user_id == user_id)
elif view_option == "shared":
query = query.filter(Knowledge.user_id != user_id)
query = has_permission(db, Knowledge, query, filter)
query = query.order_by(Knowledge.updated_at.desc())
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
knowledge_bases = []
for knowledge_base, user in items:
knowledge_bases.append(
KnowledgeUserModel.model_validate(
{
**KnowledgeModel.model_validate(
knowledge_base
).model_dump(),
"user": (
UserModel.model_validate(user).model_dump()
if user
else None
),
}
)
)
return KnowledgeListResponse(items=knowledge_bases, total=total)
except Exception as e:
print(e)
return KnowledgeListResponse(items=[], total=0)
def search_knowledge_files(
self, filter: dict, skip: int = 0, limit: int = 30
) -> KnowledgeFileListResponse:
"""
Scalable version: search files across all knowledge bases the user has
READ access to, without loading all KBs or using large IN() lists.
"""
try:
with get_db() as db:
# Base query: join Knowledge → KnowledgeFile → File
query = (
db.query(File, User)
.join(KnowledgeFile, File.id == KnowledgeFile.file_id)
.join(Knowledge, KnowledgeFile.knowledge_id == Knowledge.id)
.outerjoin(User, User.id == KnowledgeFile.user_id)
)
# Apply access-control directly to the joined query
# This makes the database handle filtering, even with 10k+ KBs
query = has_permission(db, Knowledge, query, filter)
# Apply filename search
if filter:
q = filter.get("query")
if q:
query = query.filter(File.filename.ilike(f"%{q}%"))
# Order by file changes
query = query.order_by(File.updated_at.desc())
# Count before pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
rows = query.all()
items = []
for file, user in rows:
items.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
)
)
return KnowledgeFileListResponse(items=items, total=total)
except Exception as e:
print("search_knowledge_files error:", e)
return KnowledgeFileListResponse(items=[], total=0)
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
knowledge = self.get_knowledge_by_id(id)
if not knowledge:

View File

@@ -39,7 +39,6 @@ from open_webui.models.knowledge import Knowledges
from open_webui.models.groups import Groups
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.routers.audio import transcribe

View File

@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from fastapi.concurrency import run_in_threadpool
import logging
from open_webui.models.groups import Groups
from open_webui.models.knowledge import (
KnowledgeFileListResponse,
Knowledges,
@@ -40,53 +41,115 @@ router = APIRouter()
# getKnowledgeBases
############################
PAGE_ITEM_COUNT = 30
class KnowledgeAccessResponse(KnowledgeUserResponse):
write_access: Optional[bool] = False
@router.get("/", response_model=list[KnowledgeAccessResponse])
async def get_knowledge(user=Depends(get_verified_user)):
# Return knowledge bases with read access
knowledge_bases = []
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
knowledge_bases = Knowledges.get_knowledge_bases()
else:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
return [
KnowledgeAccessResponse(
**knowledge_base.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge_base.id),
write_access=(
user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control)
),
)
for knowledge_base in knowledge_bases
]
class KnowledgeAccessListResponse(BaseModel):
items: list[KnowledgeAccessResponse]
total: int
@router.get("/list", response_model=list[KnowledgeAccessResponse])
async def get_knowledge_list(user=Depends(get_verified_user)):
# Return knowledge bases with write access
knowledge_bases = []
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
knowledge_bases = Knowledges.get_knowledge_bases()
else:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read")
@router.get("/", response_model=KnowledgeAccessListResponse)
async def get_knowledge_bases(page: Optional[int] = 1, user=Depends(get_verified_user)):
page = max(page, 1)
limit = PAGE_ITEM_COUNT
skip = (page - 1) * limit
return [
KnowledgeAccessResponse(
**knowledge_base.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge_base.id),
write_access=(
user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control)
),
)
for knowledge_base in knowledge_bases
]
filter = {}
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
result = Knowledges.search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit
)
return KnowledgeAccessListResponse(
items=[
KnowledgeAccessResponse(
**knowledge_base.model_dump(),
write_access=(
user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control)
),
)
for knowledge_base in result.items
],
total=result.total,
)
@router.get("/search", response_model=KnowledgeAccessListResponse)
async def search_knowledge_bases(
query: Optional[str] = None,
view_option: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
):
page = max(page, 1)
limit = PAGE_ITEM_COUNT
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if view_option:
filter["view_option"] = view_option
if not user.role == "admin" or not BYPASS_ADMIN_ACCESS_CONTROL:
groups = Groups.get_groups_by_member_id(user.id)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
result = Knowledges.search_knowledge_bases(
user.id, filter=filter, skip=skip, limit=limit
)
return KnowledgeAccessListResponse(
items=[
KnowledgeAccessResponse(
**knowledge_base.model_dump(),
write_access=(
user.id == knowledge_base.user_id
or has_access(user.id, "write", knowledge_base.access_control)
),
)
for knowledge_base in result.items
],
total=result.total,
)
@router.get("/search/files", response_model=KnowledgeFileListResponse)
async def search_knowledge_files(
query: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
):
page = max(page, 1)
limit = PAGE_ITEM_COUNT
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
groups = Groups.get_groups_by_member_id(user.id)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
return Knowledges.search_knowledge_files(filter=filter, skip=skip, limit=limit)
############################
@@ -198,7 +261,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
class KnowledgeFilesResponse(KnowledgeResponse):
files: list[FileMetadataResponse]
files: Optional[list[FileMetadataResponse]] = None
write_access: Optional[bool] = False
@@ -215,7 +278,6 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id),
write_access=(
user.id == knowledge.user_id
or has_access(user.id, "write", knowledge.access_control)

View File

@@ -0,0 +1,130 @@
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func
def has_permission(db, DocumentModel, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
dialect_name = db.bind.dialect.name
conditions = []
# Handle read_only permission separately
if permission == "read_only":
# For read_only, we want items where:
# 1. User has explicit read permission (via groups or user-level)
# 2. BUT does NOT have write permission
# 3. Public items are NOT considered read_only
read_conditions = []
# Group-level read permission
if group_ids:
group_read_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_read_conditions.append(
DocumentModel.access_control["read"]["group_ids"].contains(
[gid]
)
)
elif dialect_name == "postgresql":
group_read_conditions.append(
cast(
DocumentModel.access_control["read"]["group_ids"],
JSONB,
).contains([gid])
)
if group_read_conditions:
read_conditions.append(or_(*group_read_conditions))
# Combine read conditions
if read_conditions:
has_read = or_(*read_conditions)
else:
# If no read conditions, return empty result
return query.filter(False)
# Now exclude items where user has write permission
write_exclusions = []
# Exclude items owned by user (they have implicit write)
if user_id:
write_exclusions.append(DocumentModel.user_id != user_id)
# Exclude items where user has explicit write permission via groups
if group_ids:
group_write_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_write_conditions.append(
DocumentModel.access_control["write"]["group_ids"].contains(
[gid]
)
)
elif dialect_name == "postgresql":
group_write_conditions.append(
cast(
DocumentModel.access_control["write"]["group_ids"],
JSONB,
).contains([gid])
)
if group_write_conditions:
# User should NOT have write permission
write_exclusions.append(~or_(*group_write_conditions))
# Exclude public items (items without access_control)
write_exclusions.append(DocumentModel.access_control.isnot(None))
write_exclusions.append(cast(DocumentModel.access_control, String) != "null")
# Combine: has read AND does not have write AND not public
if write_exclusions:
query = query.filter(and_(has_read, *write_exclusions))
else:
query = query.filter(has_read)
return query
# Original logic for other permissions (read, write, etc.)
# Public access conditions
if group_ids or user_id:
conditions.extend(
[
DocumentModel.access_control.is_(None),
cast(DocumentModel.access_control, String) == "null",
]
)
# User-level permission (owner has all permissions)
if user_id:
conditions.append(DocumentModel.user_id == user_id)
# Group-level permission
if group_ids:
group_conditions = []
for gid in group_ids:
if dialect_name == "sqlite":
group_conditions.append(
DocumentModel.access_control[permission]["group_ids"].contains(
[gid]
)
)
elif dialect_name == "postgresql":
group_conditions.append(
cast(
DocumentModel.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions))
if conditions:
query = query.filter(or_(*conditions))
return query