From b00be730ac5e739219486ca7b8c9a5b897fd723f Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Tue, 20 Aug 2024 10:06:30 -0700 Subject: [PATCH] Allow users to pass node styler rules to the visualizer component PiperOrigin-RevId: 665412331 --- .../src/components/visualizer/common/types.ts | 38 ++++++++++++++++--- .../src/components/visualizer/common/utils.ts | 17 +++++++++ .../visualizer/common/visualizer_config.ts | 7 +++- .../visualizer/model_graph_visualizer.ts | 4 ++ .../visualizer/node_styler_dialog.ts | 10 +---- .../visualizer/node_styler_service.ts | 36 +++++++----------- .../components/visualizer/renderer_wrapper.ts | 18 ++------- .../components/visualizer/webgl_renderer.ts | 27 +++++++++---- 8 files changed, 97 insertions(+), 60 deletions(-) diff --git a/src/ui/src/components/visualizer/common/types.ts b/src/ui/src/components/visualizer/common/types.ts index 4db5c7d8..4ad247f0 100644 --- a/src/ui/src/components/visualizer/common/types.ts +++ b/src/ui/src/components/visualizer/common/types.ts @@ -536,12 +536,38 @@ export declare interface ShowOnEdgeItemData { selected: boolean; } +/** The ids of the node style. */ +export enum NodeStyleId { + NODE_BG_COLOR = 'node_bg_color', + NODE_TEXT_COLOR = 'node_text_color', + NODE_BORDER_COLOR = 'node_border_color', +} + /** A rule for node styler. All fields should be serializable. */ export declare interface NodeStylerRule { + /** + * Quries are connected with AND. + */ queries: NodeQuery[]; - nodeType: SearchNodeType; - // Indexed by style ids. - styles: Record; + + /** + * The type of node to match. + * + * @deprecated The new version the of rule stores node type as a query in + * `queries` above. + */ + nodeType?: SearchNodeType; + + /** + * Styles applied to the matched nodes. + * + * Indexed by style ids. + */ + styles: Partial>; + + /** + * Should set this to V2. + */ version?: NodeStylerRuleVersion; } @@ -557,8 +583,8 @@ declare interface NodeQueryBase { /** A rule width processed node styler rules. */ export interface ProcessedNodeStylerRule { queries: ProcessedNodeQuery[]; - nodeType: SearchNodeType; - styles: Record; + nodeType?: SearchNodeType; + styles: Record; } declare interface ProcessedNodeQueryBase { @@ -621,6 +647,6 @@ export enum SearchNodeType { /** Serialized style. */ export declare interface SerializedStyle { - id: string; + id: NodeStyleId; value: string; } diff --git a/src/ui/src/components/visualizer/common/utils.ts b/src/ui/src/components/visualizer/common/utils.ts index b83b4ecf..6ba4beed 100644 --- a/src/ui/src/components/visualizer/common/utils.ts +++ b/src/ui/src/components/visualizer/common/utils.ts @@ -41,6 +41,7 @@ import { NodeDataProviderRunData, NodeQuery, NodeQueryType, + NodeStyleId, NodeStylerRule, Point, ProcessedNodeQuery, @@ -933,3 +934,19 @@ export function getHighQualityPixelRatio(): number { ? 1.5 /* This makes rendering result sharper on non-retina displays */ : window.devicePixelRatio; } + +/** Get the value for the given style. */ +export function getNodeStyleValue( + rule: ProcessedNodeStylerRule | NodeStylerRule, + styleId: NodeStyleId, +): string { + const curStyle = rule.styles[styleId]; + if (curStyle) { + if (typeof curStyle === 'string') { + return curStyle; + } else { + return curStyle.value; + } + } + return ''; +} diff --git a/src/ui/src/components/visualizer/common/visualizer_config.ts b/src/ui/src/components/visualizer/common/visualizer_config.ts index 564a083f..8c5de62f 100644 --- a/src/ui/src/components/visualizer/common/visualizer_config.ts +++ b/src/ui/src/components/visualizer/common/visualizer_config.ts @@ -16,7 +16,7 @@ * ============================================================================== */ -import {RendererType} from './types'; +import {NodeStylerRule, RendererType} from './types'; /** Configs for the visualizer. */ export declare interface VisualizerConfig { @@ -50,8 +50,13 @@ export declare interface VisualizerConfig { /** Whether to keep layers with a single child. */ keepLayersWithASingleChild?: boolean; + /** The default node styler rules. */ + nodeStylerRules?: NodeStylerRule[]; + /** * Default graph renderer. + * + * @deprecated This field is no longer used. */ defaultRenderer?: RendererType; diff --git a/src/ui/src/components/visualizer/model_graph_visualizer.ts b/src/ui/src/components/visualizer/model_graph_visualizer.ts index df16e0e3..93fe6231 100644 --- a/src/ui/src/components/visualizer/model_graph_visualizer.ts +++ b/src/ui/src/components/visualizer/model_graph_visualizer.ts @@ -129,6 +129,7 @@ export class ModelGraphVisualizer implements OnInit, OnDestroy, OnChanges { private readonly threejsService: ThreejsService, private readonly uiStateService: UiStateService, private readonly nodeDataProviderExtensionService: NodeDataProviderExtensionService, + private readonly nodeStylerService: NodeStylerService, ) { effect(() => { const curUiState = this.uiStateService.curUiState(); @@ -184,6 +185,9 @@ export class ModelGraphVisualizer implements OnInit, OnDestroy, OnChanges { this.appService.config.set(this.config || {}); this.appService.addGraphCollections(this.graphCollections); this.appService.curInitialUiState.set(this.initialUiState); + if (this.config?.nodeStylerRules) { + this.nodeStylerService.rules.set(this.config.nodeStylerRules); + } // No initial ui state. Use the graph with the most node counts as the // default selected graph. diff --git a/src/ui/src/components/visualizer/node_styler_dialog.ts b/src/ui/src/components/visualizer/node_styler_dialog.ts index 959387b8..23142808 100644 --- a/src/ui/src/components/visualizer/node_styler_dialog.ts +++ b/src/ui/src/components/visualizer/node_styler_dialog.ts @@ -45,6 +45,7 @@ import { SearchMatchType, SearchNodeType, } from './common/types'; +import {getNodeStyleValue} from './common/utils'; import {ComplexQueries} from './complex_queries'; import {NodeListViewer} from './node_list_viewer'; import { @@ -205,13 +206,6 @@ export class NodeStylerDialog { this.nodeStylerService.updateStyleValue(index, style, `${n}`); } - handleNodeTypeChanged(ruleIndex: number, nodeType: string) { - this.nodeStylerService.updateNodeType( - ruleIndex, - nodeType as SearchNodeType, - ); - } - handleMoveUpRule(index: number) { this.nodeStylerService.moveUpRule(index); } @@ -233,7 +227,7 @@ export class NodeStylerDialog { } getSerializedStyleValue(rule: NodeStylerRule, style: Style): string { - return rule.styles[style.id]?.value || ''; + return getNodeStyleValue(rule, style.id); } getMatchedNodes(ruleIndex: number, paneIndex: number): ModelNode[] { diff --git a/src/ui/src/components/visualizer/node_styler_service.ts b/src/ui/src/components/visualizer/node_styler_service.ts index f49a866b..8d5e014a 100644 --- a/src/ui/src/components/visualizer/node_styler_service.ts +++ b/src/ui/src/components/visualizer/node_styler_service.ts @@ -24,6 +24,7 @@ import {ModelNode} from './common/model_graph'; import { NodeQuery, NodeQueryType, + NodeStyleId, NodeStylerRule, NodeStylerRuleVersion, NodeTypeQuery, @@ -41,7 +42,7 @@ import {LocalStorageService} from './local_storage_service'; export interface Style { type: StyleType; label: string; - id: string; + id: NodeStyleId; defaultValue: string; } @@ -51,19 +52,11 @@ export enum StyleType { NUMBER = 'NUMBER', } -/** Style ids. */ -export enum StyleId { - NODE_BG_COLOR = 'node_bg_color', - NODE_TEXT_COLOR = 'node_text_color', - NODE_BORDER_COLOR = 'node_border_color', - NODE_SCALE = 'node_scale', -} - /** Style for node background color. */ export const NODE_BG_COLOR_STYLE: Style = { type: StyleType.COLOR, label: 'Bg color', - id: StyleId.NODE_BG_COLOR, + id: NodeStyleId.NODE_BG_COLOR, defaultValue: '#ffffff', }; @@ -71,7 +64,7 @@ export const NODE_BG_COLOR_STYLE: Style = { export const NODE_BORDER_COLOR_STYLE: Style = { type: StyleType.COLOR, label: 'Border color', - id: StyleId.NODE_BORDER_COLOR, + id: NodeStyleId.NODE_BORDER_COLOR, defaultValue: '#777777', }; @@ -79,7 +72,7 @@ export const NODE_BORDER_COLOR_STYLE: Style = { export const NODE_TEXT_COLOR_STYLE: Style = { type: StyleType.COLOR, label: 'Text color', - id: StyleId.NODE_TEXT_COLOR, + id: NodeStyleId.NODE_TEXT_COLOR, defaultValue: '#041e49', }; @@ -251,18 +244,17 @@ export class NodeStylerService { }); } - updateNodeType(ruleIndex: number, nodeType: SearchNodeType) { - this.rules.update((rules) => { - const rule = rules[ruleIndex]; - rule.nodeType = nodeType; - return [...rules]; - }); - } - updateStyleValue(ruleIndex: number, style: Style, value: string) { this.rules.update((rules) => { const rule = rules[ruleIndex]; - rule.styles[style.id].value = value; + const curStyle = rule.styles[style.id]; + if (curStyle) { + if (typeof curStyle === 'string') { + rule.styles[style.id] = value; + } else { + curStyle.value = value; + } + } return [...rules]; }); } @@ -279,7 +271,7 @@ export class NodeStylerService { return rules.map((rule) => { // For older version of the rule, convert the node type to // a query. - if (rule.version == null) { + if (rule.version == null && rule.nodeType) { const nodeTypeQuery: NodeTypeQuery = { type: NodeQueryType.NODE_TYPE, nodeType: rule.nodeType, diff --git a/src/ui/src/components/visualizer/renderer_wrapper.ts b/src/ui/src/components/visualizer/renderer_wrapper.ts index baac8996..216545b0 100644 --- a/src/ui/src/components/visualizer/renderer_wrapper.ts +++ b/src/ui/src/components/visualizer/renderer_wrapper.ts @@ -26,7 +26,6 @@ import { effect, EventEmitter, Input, - OnInit, Output, ViewChild, } from '@angular/core'; @@ -43,7 +42,6 @@ import {AppService} from './app_service'; import {type ModelGraph} from './common/model_graph'; import { PopupPanelData, - RendererType, SelectedNodeInfo, SubgraphBreadcrumbItem, } from './common/types'; @@ -54,8 +52,6 @@ import {SubgraphBreadcrumbs} from './subgraph_breadcrumbs'; import {ViewOnNode} from './view_on_node'; import {WebglRenderer} from './webgl_renderer'; -const DEFAULT_RENDERER_TYPE: RendererType = RendererType.WEBGL; - /** A wrapper panel around various renderers. */ @Component({ standalone: true, @@ -79,7 +75,7 @@ const DEFAULT_RENDERER_TYPE: RendererType = RendererType.WEBGL; styleUrls: ['./renderer_wrapper.scss'], changeDetection: ChangeDetectionStrategy.OnPush, }) -export class RendererWrapper implements OnInit { +export class RendererWrapper { @Input({required: true}) modelGraph!: ModelGraph; @Input({required: true}) rendererId!: string; @Input({required: true}) paneId!: string; @@ -90,8 +86,6 @@ export class RendererWrapper implements OnInit { @ViewChild('webglRenderer') webglRenderer?: WebglRenderer; - readonly RendererType = RendererType; - readonly helpPopupSize: OverlaySizeConfig = { minWidth: 0, minHeight: 0, @@ -101,7 +95,6 @@ export class RendererWrapper implements OnInit { flattenAllLayers = computed(() => this.appService.getFlattenLayers(this.paneId), ); - rendererType = DEFAULT_RENDERER_TYPE; disableDownloadPngHelpPopup = false; transparentPngBackground = new FormControl(false); @@ -118,11 +111,6 @@ export class RendererWrapper implements OnInit { }); } - ngOnInit() { - this.rendererType = - this.appService.config()?.defaultRenderer ?? DEFAULT_RENDERER_TYPE; - } - handleOpenOnPopupClicked(data: PopupPanelData) { this.openInPopupClicked.emit(data); } @@ -205,11 +193,11 @@ export class RendererWrapper implements OnInit { } get showDownloadPng(): boolean { - return !this.inPopup && this.rendererType === RendererType.WEBGL; + return !this.inPopup; } get showSnapshotManager(): boolean { - return !this.inPopup && this.rendererType === RendererType.WEBGL; + return !this.inPopup; } get showSubgraphBreadcrumbs(): boolean { diff --git a/src/ui/src/components/visualizer/webgl_renderer.ts b/src/ui/src/components/visualizer/webgl_renderer.ts index a9361a88..d2d598ac 100644 --- a/src/ui/src/components/visualizer/webgl_renderer.ts +++ b/src/ui/src/components/visualizer/webgl_renderer.ts @@ -64,6 +64,7 @@ import { FontWeight, NodeDataProviderResultProcessedData, NodeDataProviderRunData, + NodeStyleId, NodeStylerRule, Point, PopupPanelData, @@ -79,6 +80,7 @@ import { import { genUid, getHighQualityPixelRatio, + getNodeStyleValue, hasNonEmptyQueries, IS_MAC, isGroupNode, @@ -97,7 +99,7 @@ import { import {DragArea} from './drag_area'; import {genIoTreeData, IoTree} from './io_tree'; import {NodeDataProviderExtensionService} from './node_data_provider_extension_service'; -import {NodeStylerService, StyleId} from './node_styler_service'; +import {NodeStylerService} from './node_styler_service'; import {SplitPaneService} from './split_pane_service'; import {SubgraphSelectionService} from './subgraph_selection_service'; import {ThreejsService} from './threejs_service'; @@ -1767,17 +1769,24 @@ export class WebglRenderer implements OnInit, OnDestroy { // Node styler. for (const rule of this.curProcessedNodeStylerRules) { if (matchNodeForQueries(node, rule.queries, this.curModelGraph)) { - const nodeStylerBgColor = - rule.styles[StyleId.NODE_BG_COLOR]?.value || ''; + const nodeStylerBgColor = getNodeStyleValue( + rule, + NodeStyleId.NODE_BG_COLOR, + ); if (nodeStylerBgColor !== '') { bgColor = new THREE.Color(nodeStylerBgColor); } - const nodeBorderColor = - rule.styles[StyleId.NODE_BORDER_COLOR]?.value || ''; + const nodeBorderColor = getNodeStyleValue( + rule, + NodeStyleId.NODE_BORDER_COLOR, + ); if (nodeBorderColor !== '') { borderColor = new THREE.Color(nodeBorderColor); } - const textColor = rule.styles[StyleId.NODE_TEXT_COLOR]?.value || ''; + const textColor = getNodeStyleValue( + rule, + NodeStyleId.NODE_TEXT_COLOR, + ); if (textColor !== '') { groupNodeIconColor = new THREE.Color(textColor); } @@ -2019,8 +2028,10 @@ export class WebglRenderer implements OnInit, OnDestroy { // Node styler. for (const rule of this.curProcessedNodeStylerRules) { if (matchNodeForQueries(node, rule.queries, this.curModelGraph)) { - const nodeStylerTextColor = - rule.styles[StyleId.NODE_TEXT_COLOR]?.value || ''; + const nodeStylerTextColor = getNodeStyleValue( + rule, + NodeStyleId.NODE_TEXT_COLOR, + ); if (nodeStylerTextColor !== '') { color = new THREE.Color(nodeStylerTextColor); }