Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions invokeai/app/api/routers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,15 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)


@workflows_router.get("/tags", operation_id="get_all_tags")
async def get_all_tags(
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
) -> list[str]:
"""Gets all unique tags from workflows"""

return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)


@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
tags: list[str] = Query(description="The tags to get counts for"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,11 @@ def counts_by_tag(
def update_opened_at(self, workflow_id: str) -> None:
"""Open a workflow."""
pass

@abstractmethod
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
) -> list[str]:
"""Gets all unique tags from workflows."""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,48 @@ def update_opened_at(self, workflow_id: str) -> None:
(workflow_id,),
)

def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
) -> list[str]:
with self._db.transaction() as cursor:
conditions: list[str] = []
params: list[str] = []

# Only get workflows that have tags
conditions.append("tags IS NOT NULL AND tags != ''")

if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
conditions.append(f"category IN ({placeholders})")
params.extend([category.value for category in categories])

stmt = """--sql
SELECT DISTINCT tags
FROM workflow_library
"""

if conditions:
stmt += " WHERE " + " AND ".join(conditions)

cursor.execute(stmt, params)
rows = cursor.fetchall()

# Parse comma-separated tags and collect unique tags
all_tags: set[str] = set()

for row in rows:
tags_value = row[0]
if tags_value and isinstance(tags_value, str):
# Tags are stored as comma-separated string
for tag in tags_value.split(","):
tag_stripped = tag.strip()
if tag_stripped:
all_tags.add(tag_stripped)

return sorted(all_tags)

def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiStarFill } from 'react-icons/pi';
import { useDispatch } from 'react-redux';
import { useGetCountsByTagQuery } from 'services/api/endpoints/workflows';
import { useGetAllTagsQuery, useGetCountsByTagQuery } from 'services/api/endpoints/workflows';

export const WorkflowLibrarySideNav = () => {
const { t } = useTranslation();
Expand All @@ -40,11 +40,11 @@ export const WorkflowLibrarySideNav = () => {
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
<Flex flexDir="column" w="full" pb={2} gap={2}>
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
<YourWorkflowsButton />
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
<BrowseWorkflowsButton />
<DefaultsViewCheckboxesCollapsible />
<TagCheckboxesCollapsible />
</Flex>
<Spacer />
<NewWorkflowButton />
Expand All @@ -53,6 +53,40 @@ export const WorkflowLibrarySideNav = () => {
);
};

const YourWorkflowsButton = memo(() => {
const { t } = useTranslation();
const view = useAppSelector(selectWorkflowLibraryView);
const dispatch = useAppDispatch();
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const resetTags = useCallback(() => {
dispatch(workflowLibraryTagsReset());
}, [dispatch]);

if (view === 'yours' && selectedTags.length > 0) {
return (
<ButtonGroup>
<WorkflowLibraryViewButton view="yours" w="auto">
{t('workflows.yourWorkflows')}
</WorkflowLibraryViewButton>
<Tooltip label={t('workflows.deselectAll')}>
<IconButton
onClick={resetTags}
size="md"
aria-label={t('workflows.deselectAll')}
icon={<PiArrowCounterClockwiseBold size={12} />}
variant="ghost"
bg="base.700"
color="base.50"
/>
</Tooltip>
</ButtonGroup>
);
}

return <WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>;
});
YourWorkflowsButton.displayName = 'YourWorkflowsButton';

const BrowseWorkflowsButton = memo(() => {
const { t } = useTranslation();
const view = useAppSelector(selectWorkflowLibraryView);
Expand Down Expand Up @@ -89,31 +123,114 @@ BrowseWorkflowsButton.displayName = 'BrowseWorkflowsButton';

const overlayscrollbarsOptions = getOverlayScrollbarsParams({ visibility: 'visible' }).options;

const DefaultsViewCheckboxesCollapsible = memo(() => {
const TagCheckboxesCollapsible = memo(() => {
const view = useAppSelector(selectWorkflowLibraryView);

return (
<Collapse in={view === 'defaults'}>
<Collapse in={view === 'defaults' || view === 'yours'}>
<Flex flexDir="column" gap={2} pl={4} py={2} overflow="hidden" h="100%" minH={0}>
<OverlayScrollbarsComponent style={overlayScrollbarsStyles} options={overlayscrollbarsOptions}>
<Flex flexDir="column" gap={2} overflow="auto">
{WORKFLOW_LIBRARY_TAG_CATEGORIES.map((tagCategory) => (
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
))}
{view === 'yours' ? <DynamicTagsList /> : <StaticTagCategories />}
</Flex>
</OverlayScrollbarsComponent>
</Flex>
</Collapse>
);
});
DefaultsViewCheckboxesCollapsible.displayName = 'DefaultsViewCheckboxes';
TagCheckboxesCollapsible.displayName = 'TagCheckboxesCollapsible';

const tagCountQueryArg = {
tags: WORKFLOW_LIBRARY_TAGS.map((tag) => tag.label),
categories: ['default'],
} satisfies Parameters<typeof useGetCountsByTagQuery>[0];
const StaticTagCategories = memo(() => {
return (
<>
{WORKFLOW_LIBRARY_TAG_CATEGORIES.map((tagCategory) => (
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
))}
</>
);
});
StaticTagCategories.displayName = 'StaticTagCategories';

const DynamicTagsList = memo(() => {
const { t } = useTranslation();
const { data: tags, isLoading } = useGetAllTagsQuery({ categories: ['user'] });

if (isLoading) {
return <Text color="base.400">{t('common.loading')}</Text>;
}

if (!tags || tags.length === 0) {
return null;
}

return (
<Flex flexDir="column" gap={2}>
{tags.map((tag) => (
<DynamicTagCheckbox key={tag} tag={tag} />
))}
</Flex>
);
});
DynamicTagsList.displayName = 'DynamicTagsList';

const DynamicTagCheckbox = memo(({ tag }: { tag: string }) => {
const dispatch = useAppDispatch();
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const isChecked = selectedTags.includes(tag);
const count = useDynamicTagCount(tag);

const onChange = useCallback(() => {
dispatch(workflowLibraryTagToggled(tag));
}, [dispatch, tag]);

if (count === 0) {
return null;
}

return (
<Flex alignItems="center" gap={2}>
<Checkbox isChecked={isChecked} onChange={onChange} flexShrink={0} />
<Text>{`${tag} (${count})`}</Text>
</Flex>
);
});
DynamicTagCheckbox.displayName = 'DynamicTagCheckbox';

const useDynamicTagCount = (tag: string) => {
const queryArg = useMemo(
() => ({
tags: [tag],
categories: ['user'] as ('user' | 'default')[],
}),
[tag]
);

const queryOptions = useMemo(
() => ({
selectFromResult: ({ data }: { data?: Record<string, number> }) => ({
count: data?.[tag] ?? 0,
}),
}),
[tag]
);

const { count } = useGetCountsByTagQuery(queryArg, queryOptions);
return count;
};

const useTagCountQueryArg = () => {
const view = useAppSelector(selectWorkflowLibraryView);
return useMemo(
() => ({
tags: WORKFLOW_LIBRARY_TAGS.map((tag) => tag.label),
categories: view === 'yours' ? ['user'] : ['default'],
}),
[view]
) satisfies Parameters<typeof useGetCountsByTagQuery>[0];
};

const useCountForIndividualTag = (tag: string) => {
const tagCountQueryArg = useTagCountQueryArg();
const queryOptions = useMemo(
() =>
({
Expand All @@ -130,6 +247,7 @@ const useCountForIndividualTag = (tag: string) => {
};

const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
const tagCountQueryArg = useTagCountQueryArg();
const queryOptions = useMemo(
() =>
({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ const useInfiniteQueryAry = () => {
direction,
categories: getCategories(view),
query: debouncedSearchTerm,
tags: view === 'defaults' ? selectedTags : [],
tags: view === 'defaults' || view === 'yours' ? selectedTags : [],
has_been_opened: getHasBeenOpened(view),
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);
Expand Down
10 changes: 10 additions & 0 deletions invokeai/frontend/web/src/services/api/endpoints/workflows.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export const workflowsApi = api.injectEndpoints({
// Because this may change the order of the list, we need to invalidate the whole list
{ type: 'Workflow', id: LIST_TAG },
{ type: 'Workflow', id: workflow_id },
'WorkflowTags',
'WorkflowTagCounts',
'WorkflowCategoryCounts',
],
Expand All @@ -46,6 +47,7 @@ export const workflowsApi = api.injectEndpoints({
invalidatesTags: [
// Because this may change the order of the list, we need to invalidate the whole list
{ type: 'Workflow', id: LIST_TAG },
'WorkflowTags',
'WorkflowTagCounts',
'WorkflowCategoryCounts',
],
Expand All @@ -61,10 +63,17 @@ export const workflowsApi = api.injectEndpoints({
}),
invalidatesTags: (response, error, workflow) => [
{ type: 'Workflow', id: workflow.id },
'WorkflowTags',
'WorkflowTagCounts',
'WorkflowCategoryCounts',
],
}),
getAllTags: build.query<string[], { categories?: ('user' | 'default')[] } | void>({
query: (params) => ({
url: `${buildWorkflowsUrl('tags')}${params ? `?${queryString.stringify(params, { arrayFormat: 'none' })}` : ''}`,
}),
providesTags: ['WorkflowTags'],
}),
getCountsByTag: build.query<
paths['/api/v1/workflows/counts_by_tag']['get']['responses']['200']['content']['application/json'],
NonNullable<paths['/api/v1/workflows/counts_by_tag']['get']['parameters']['query']>
Expand Down Expand Up @@ -153,6 +162,7 @@ export const workflowsApi = api.injectEndpoints({

export const {
useUpdateOpenedAtMutation,
useGetAllTagsQuery,
useGetCountsByTagQuery,
useGetCountsByCategoryQuery,
useLazyGetWorkflowQuery,
Expand Down
1 change: 1 addition & 0 deletions invokeai/frontend/web/src/services/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const tagTypes = [
'LoRAModel',
'SDXLRefinerModel',
'Workflow',
'WorkflowTags',
'WorkflowTagCounts',
'WorkflowCategoryCounts',
'StylePreset',
Expand Down
52 changes: 52 additions & 0 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1712,6 +1712,26 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/workflows/tags": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get All Tags
* @description Gets all unique tags from workflows
*/
get: operations["get_all_tags"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/workflows/counts_by_tag": {
parameters: {
query?: never;
Expand Down Expand Up @@ -28231,6 +28251,38 @@ export interface operations {
};
};
};
get_all_tags: {
parameters: {
query?: {
/** @description The categories to include */
categories?: components["schemas"]["WorkflowCategory"][] | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": string[];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_counts_by_tag: {
parameters: {
query: {
Expand Down