# Source: BastionSSO/user_manage_api/app/services/user_group_service.py
# (repository viewer metadata at capture time: 206 lines, 8.6 KiB, Python)

from pathlib import PurePosixPath
from typing import List, Optional
from app.core.errors import ApiError
from app.core.models import GroupSummary, UserCreateRequest, UserEnvironmentBatchResult, UserEnvironmentFailure, UserSummary
from app.providers.base import SystemProvider
class UserGroupService:
    """Policy layer that sits between API handlers and a ``SystemProvider``.

    Each operation is checked against the configured visibility rules and the
    locked-user list before it is delegated to the provider.

    Visibility precedence (identical for users and groups):

    1. a whitelisted name is always visible;
    2. otherwise a hidden name is never visible;
    3. otherwise the entry is visible when its UID/GID lies inside the
       configured inclusive range (an unset bound means "unbounded").

    Entries that are not visible are reported as ``ApiError(404, "not_found")``.
    """

    def __init__(
        self,
        provider: SystemProvider,
        home_base_dir: str,
        link_home_dir: Optional[str] = None,
        default_shell: str = "/bin/bash",
        hidden_users: Optional[List[str]] = None,
        hidden_groups: Optional[List[str]] = None,
        whitelist_users: Optional[List[str]] = None,
        whitelist_groups: Optional[List[str]] = None,
        locked_users: Optional[List[str]] = None,
        user_uid_min: Optional[int] = None,
        user_uid_max: Optional[int] = None,
        group_gid_min: Optional[int] = None,
        group_gid_max: Optional[int] = None,
    ):
        """Store the provider and normalise the policy configuration.

        Args:
            provider: Backend that performs the actual user/group operations.
            home_base_dir: Directory under which ``<username>`` home paths are
                resolved.
            link_home_dir: Optional second base directory; when falsy no linked
                home path is produced.
            default_shell: Shell passed to the provider for new users.
            hidden_users / hidden_groups: Names that are never visible unless
                whitelisted.
            whitelist_users / whitelist_groups: Names that are always visible.
            locked_users: Usernames that may not be modified through this
                service.
            user_uid_min / user_uid_max: Inclusive UID visibility bounds.
            group_gid_min / group_gid_max: Inclusive GID visibility bounds.
        """
        self.provider = provider
        self.home_base_dir = PurePosixPath(home_base_dir)
        self.link_home_base_dir = PurePosixPath(link_home_dir) if link_home_dir else None
        self.default_shell = default_shell
        # Name lists become sets: they are only used for membership tests.
        self.hidden_users = set(hidden_users or [])
        self.hidden_groups = set(hidden_groups or [])
        self.whitelist_users = set(whitelist_users or [])
        self.whitelist_groups = set(whitelist_groups or [])
        self.locked_users = set(locked_users or [])
        self.user_uid_min = user_uid_min
        self.user_uid_max = user_uid_max
        self.group_gid_min = group_gid_min
        self.group_gid_max = group_gid_max

    # ------------------------------------------------------------------
    # Guards
    # ------------------------------------------------------------------

    def _ensure_user_visible(self, username: str) -> None:
        """Raise ``ApiError(404)`` unless *username* resolves to a visible user."""
        # Delegates to get_user so the lookup + 404 policy lives in one place.
        self.get_user(username)

    def _ensure_groups_visible(self, groups: List[str]) -> None:
        """Raise ``ApiError(404)`` on the first group in *groups* that is not visible."""
        for groupname in groups:
            self._ensure_group_visible(groupname)

    def _ensure_group_visible(self, groupname: str) -> None:
        """Raise ``ApiError(404)`` unless *groupname* resolves to a visible group."""
        self.get_group(groupname)

    def _ensure_user_name_allowed(self, username: str) -> None:
        """Reject a username that is hidden and not whitelisted.

        Name-only check (no provider lookup), used before a user exists.
        """
        if username not in self.whitelist_users and username in self.hidden_users:
            raise ApiError(404, "not_found", "user not found")

    def _ensure_user_unlocked(self, username: str) -> None:
        """Raise ``ApiError(423)`` when *username* is in the locked-user set."""
        if username in self.locked_users:
            raise ApiError(423, "user_locked", "user is locked and cannot be modified")

    def _ensure_group_name_allowed(self, groupname: str) -> None:
        """Reject a group name that is hidden and not whitelisted.

        Name-only check (no provider lookup), used before a group exists.
        """
        if groupname not in self.whitelist_groups and groupname in self.hidden_groups:
            raise ApiError(404, "not_found", "group not found")

    # ------------------------------------------------------------------
    # Visibility predicates
    # ------------------------------------------------------------------

    def _is_uid_in_range(self, uid: int) -> bool:
        """Return True when *uid* lies within the inclusive UID bounds."""
        if self.user_uid_min is not None and uid < self.user_uid_min:
            return False
        if self.user_uid_max is not None and uid > self.user_uid_max:
            return False
        return True

    def _is_gid_in_range(self, gid: int) -> bool:
        """Return True when *gid* lies within the inclusive GID bounds."""
        if self.group_gid_min is not None and gid < self.group_gid_min:
            return False
        if self.group_gid_max is not None and gid > self.group_gid_max:
            return False
        return True

    def _is_user_visible(self, user: UserSummary) -> bool:
        """Apply whitelist > hidden > UID-range precedence to *user*."""
        if user.username in self.whitelist_users:
            return True
        if user.username in self.hidden_users:
            return False
        return self._is_uid_in_range(user.uid)

    def _is_group_visible(self, group: GroupSummary) -> bool:
        """Apply whitelist > hidden > GID-range precedence to *group*."""
        if group.groupname in self.whitelist_groups:
            return True
        if group.groupname in self.hidden_groups:
            return False
        return self._is_gid_in_range(group.gid)

    # ------------------------------------------------------------------
    # Home directory resolution
    # ------------------------------------------------------------------

    @staticmethod
    def _is_safe_path_component(name: str) -> bool:
        """Return True when *name* is a single, non-special path component.

        Joining onto a ``PurePosixPath`` is not confined to the base: an
        absolute segment replaces it, ``..`` climbs out of it, and an empty
        string resolves to the base directory itself.
        """
        return (
            bool(name)
            and name not in {".", ".."}
            and "/" not in name
            and "\x00" not in name
        )

    def _join_user_dir(self, base: PurePosixPath, username: str) -> str:
        """Return ``base/username`` as a string, refusing unsafe usernames.

        Raises:
            ApiError: 400 when *username* is not a single path component.
        """
        if not self._is_safe_path_component(username):
            raise ApiError(400, "invalid_username", "username is not a valid path component")
        return str(base / username)

    def _resolve_home_dir(self, username: str) -> str:
        """Return the home directory path for *username* under ``home_base_dir``."""
        return self._join_user_dir(self.home_base_dir, username)

    def _resolve_linked_home_dir(self, username: str) -> Optional[str]:
        """Return the linked home path for *username*, or None when not configured."""
        if self.link_home_base_dir is None:
            return None
        return self._join_user_dir(self.link_home_base_dir, username)

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    def create_user(self, payload: UserCreateRequest) -> None:
        """Create a user through the provider after policy checks.

        The username must not be hidden or locked, and every referenced group
        (primary and supplementary) must be visible. When the payload carries
        non-blank default environment variables they are written afterwards.

        Raises:
            ApiError: 404 for a hidden username or invisible group, 423 for a
                locked username, 400 for a username unusable as a directory
                name; provider errors propagate unchanged.
        """
        self._ensure_user_name_allowed(payload.username)
        self._ensure_user_unlocked(payload.username)
        if payload.primary_group is not None:
            self._ensure_group_visible(payload.primary_group)
        # Tolerate a missing group list; the provider still receives the
        # payload value untouched.
        self._ensure_groups_visible(payload.groups or [])

        home_dir = self._resolve_home_dir(payload.username)
        linked_home_dir = self._resolve_linked_home_dir(payload.username)
        self.provider.create_user(
            username=payload.username,
            password_hash=payload.password_hash,
            home_dir=home_dir,
            linked_home_dir=linked_home_dir,
            shell=self.default_shell,
            primary_group=payload.primary_group,
            groups=payload.groups,
        )

        # None is treated like an empty string: nothing to write.
        if (payload.default_environment_variables or "").strip():
            self.provider.write_default_user_environment(
                payload.username, payload.default_environment_variables
            )

    def delete_user(self, username: str) -> None:
        """Delete a visible, unlocked user through the provider."""
        self._ensure_user_visible(username)
        self._ensure_user_unlocked(username)
        self.provider.delete_user(username)

    def change_user_password(self, username: str, password_hash: str) -> None:
        """Set a new password hash for a visible, unlocked user."""
        self._ensure_user_visible(username)
        self._ensure_user_unlocked(username)
        self.provider.change_user_password(username, password_hash)

    def list_users(self) -> List[UserSummary]:
        """Return the provider's users filtered down to the visible ones."""
        return [user for user in self.provider.list_users() if self._is_user_visible(user)]

    def get_user(self, username: str) -> UserSummary:
        """Return the user named *username*.

        Raises:
            ApiError: 404 when the user exists but is not visible; errors from
                the provider lookup propagate unchanged.
        """
        user = self.provider.get_user(username)
        if not self._is_user_visible(user):
            raise ApiError(404, "not_found", "user not found")
        return user

    # ------------------------------------------------------------------
    # Groups
    # ------------------------------------------------------------------

    def create_group(self, groupname: str) -> None:
        """Create a group unless its name is hidden and not whitelisted."""
        self._ensure_group_name_allowed(groupname)
        self.provider.create_group(groupname)

    def delete_group(self, groupname: str) -> None:
        """Delete a visible group that has no members.

        Raises:
            ApiError: 404 when the group is not visible, 422 when it still has
                members.
        """
        # One lookup serves both the visibility and the membership check, so
        # both decisions are made against the same snapshot of the group.
        group = self.get_group(groupname)
        if group.members:
            raise ApiError(422, "precondition_failed", "Group has members and cannot be deleted.")
        self.provider.delete_group(groupname)

    def list_groups(self) -> List[GroupSummary]:
        """Return the provider's groups filtered down to the visible ones."""
        return [group for group in self.provider.list_groups() if self._is_group_visible(group)]

    def get_group(self, groupname: str) -> GroupSummary:
        """Return the group named *groupname*.

        Raises:
            ApiError: 404 when the group exists but is not visible; errors from
                the provider lookup propagate unchanged.
        """
        group = self.provider.get_group(groupname)
        if not self._is_group_visible(group):
            raise ApiError(404, "not_found", "group not found")
        return group

    # ------------------------------------------------------------------
    # Membership
    # ------------------------------------------------------------------

    def add_user_groups(self, username: str, groups: List[str], replace: bool) -> None:
        """Add *groups* to a visible, unlocked user; all groups must be visible.

        ``replace`` is forwarded to the provider unchanged.
        """
        # NOTE(review): with replace=True the provider may also drop the user's
        # memberships in groups hidden from this API - confirm provider semantics.
        self._ensure_user_visible(username)
        self._ensure_user_unlocked(username)
        self._ensure_groups_visible(groups)
        self.provider.add_user_groups(username, groups, replace)

    def remove_user_groups(self, username: str, groups: List[str]) -> None:
        """Remove *groups* from a visible, unlocked user; all groups must be visible."""
        self._ensure_user_visible(username)
        self._ensure_user_unlocked(username)
        self._ensure_groups_visible(groups)
        self.provider.remove_user_groups(username, groups)

    def get_user_groups(self, username: str) -> List[str]:
        """Return the names of the visible groups that *username* belongs to."""
        self._ensure_user_visible(username)
        # Each name is resolved through the provider because visibility can
        # depend on the group's GID, not only on its name.
        return [
            groupname
            for groupname in self.provider.get_user_groups(username)
            if self._is_group_visible(self.provider.get_group(groupname))
        ]

    # ------------------------------------------------------------------
    # Environment
    # ------------------------------------------------------------------

    def get_user_environment(self, username: str) -> str:
        """Return the environment content the provider reads for a visible user."""
        self._ensure_user_visible(username)
        return self.provider.read_user_environment(username)

    def set_user_environment(self, username: str, content: str) -> None:
        """Write the managed environment content for a visible, unlocked user."""
        self._ensure_user_visible(username)
        self._ensure_user_unlocked(username)
        self.provider.write_managed_user_environment(username, content)

    def set_all_user_environments(self, content: str) -> UserEnvironmentBatchResult:
        """Write *content* as the managed environment of every visible user.

        Locked users are skipped silently (they appear in neither list). An
        ``ApiError`` raised for one user is recorded and does not stop the
        batch; any other exception propagates.
        """
        updated = []
        failed = []
        for user in self.list_users():
            if user.username in self.locked_users:
                continue
            try:
                self.provider.write_managed_user_environment(user.username, content)
            except ApiError as exception:
                failed.append(
                    UserEnvironmentFailure(
                        username=user.username,
                        code=exception.code,
                        message=exception.message,
                    )
                )
            else:
                updated.append(user.username)
        return UserEnvironmentBatchResult(
            message="User environments updated.",
            updated_users=updated,
            failed_users=failed,
            updated_count=len(updated),
            failed_count=len(failed),
        )