diff --git a/src/ui/src/components/visualizer/common/input_graph.ts b/src/ui/src/components/visualizer/common/input_graph.ts index 36523bfc..a4c080ac 100644 --- a/src/ui/src/components/visualizer/common/input_graph.ts +++ b/src/ui/src/components/visualizer/common/input_graph.ts @@ -19,6 +19,7 @@ import { GraphNodeConfig, GraphNodeStyle, + GroupNodeAttributes, IncomingEdge, KeyValueList, MetadataItem, @@ -68,6 +69,13 @@ export declare interface Graph { /** A list of nodes in the graph. */ nodes: GraphNode[]; + /** + * Attributes for group nodes. + * + * It is displayed in the side panel when the group is selected. + */ + groupNodeAttributes?: GroupNodeAttributes; + ////////////////////////////////////////////////////////////////////////////// // The following fields are set by model explorer. Users don't need to set // them. diff --git a/src/ui/src/components/visualizer/common/model_graph.ts b/src/ui/src/components/visualizer/common/model_graph.ts index b3c82482..1af31451 100644 --- a/src/ui/src/components/visualizer/common/model_graph.ts +++ b/src/ui/src/components/visualizer/common/model_graph.ts @@ -19,6 +19,7 @@ import { GraphNodeConfig, GraphNodeStyle, + GroupNodeAttributes, IncomingEdge, KeyValuePairs, OutgoingEdge, @@ -45,6 +46,9 @@ export declare interface ModelGraph { /** All nodes in the model graph. */ nodes: ModelNode[]; + /** Attributes for group nodes. */ + groupNodeAttributes?: GroupNodeAttributes; + /** Ids of all group nodes that are artificially created. */ artificialGroupNodeIds?: string[]; diff --git a/src/ui/src/components/visualizer/common/types.ts b/src/ui/src/components/visualizer/common/types.ts index 1888f4dc..4db5c7d8 100644 --- a/src/ui/src/components/visualizer/common/types.ts +++ b/src/ui/src/components/visualizer/common/types.ts @@ -86,6 +86,20 @@ export declare interface Rect { height: number; } +/** Attributes for group nodes. */ +export declare interface GroupNodeAttributes { + /** + * From group's namespace to its attribuets (key-value pairs). + * + * Use empty group namespace for the model-level attributes (i.e. shown in + * side panel when no node is selected). + */ + [namespaceName: string]: Record; +} + +/** A single attribute item for group node. */ +export type GroupNodeAttributeItem = string; + /** The style of the op node. */ export declare interface GraphNodeStyle { /** @@ -475,6 +489,7 @@ export enum ShowOnNodeItemType { OP_OUTPUTS = 'Op node outputs', LAYER_NODE_CHILDREN_COUNT = 'Layer node children count', LAYER_NODE_DESCENDANTS_COUNT = 'Layer node descendants count', + LAYER_NODE_ATTRS = 'Layer node attributes', } /** Item types to be shown on edge. */ diff --git a/src/ui/src/components/visualizer/common/utils.ts b/src/ui/src/components/visualizer/common/utils.ts index 9acf9341..b83b4ecf 100644 --- a/src/ui/src/components/visualizer/common/utils.ts +++ b/src/ui/src/components/visualizer/common/utils.ts @@ -398,7 +398,7 @@ export function getOpNodeAttrsKeyValuePairsForAttrsTable( processedValue = value.replace(/\s/gm, ''); } else { // For other attributes, only remove newline chars. - processedValue = value.replace(/(\r\n|\n|\r)/gm, ''); + processedValue = value.replace(/(\r\n|\n|\r)/gm, ' '); } keyValuePairs.push({ key, @@ -409,6 +409,37 @@ export function getOpNodeAttrsKeyValuePairsForAttrsTable( return keyValuePairs; } +/** + * Gets the key value pairs for the given group node's attrs for attrs table. + */ +export function getGroupNodeAttrsKeyValuePairsForAttrsTable( + node: GroupNode, + modelGraph: ModelGraph, + filterRegex = '', +) { + const attrs = + modelGraph.groupNodeAttributes?.[node.id.replace('___group___', '')] || {}; + const keyValuePairs: KeyValueList = []; + const regex = new RegExp(filterRegex, 'i'); + for (const attrId of Object.keys(attrs)) { + const key = attrId; + const value = attrs[attrId]; + const matchTargets = [`${key}:${value}`, `${key}=${value}`]; + if ( + filterRegex.trim() === '' || + matchTargets.some((matchTarget) => regex.test(matchTarget)) + ) { + // Remove new line chars and spaces. + const processedValue = value.replace(/(\r\n|\n|\r)/gm, ' '); + keyValuePairs.push({ + key, + value: processedValue, + }); + } + } + return keyValuePairs; +} + /** Gets the key value pairs for the givn node's input for attrs table. */ export function getOpNodeInputsKeyValuePairsForAttrsTable( node: OpNode, @@ -591,7 +622,7 @@ export function getRegexMatchesForNode( } // Attribute. if (shouldMatchTypes.has(SearchMatchType.ATTRIBUTE)) { - const attrs = getAttributesFromNode(node); + const attrs = getAttributesFromNode(node, modelGraph); for (const attrId of Object.keys(attrs)) { const value = attrs[attrId]; const text1 = `${attrId}:${value}`; @@ -724,7 +755,10 @@ export function getRegexMatchesForNode( } /** Gets the attributes from the given node. */ -export function getAttributesFromNode(node: ModelNode): KeyValuePairs { +export function getAttributesFromNode( + node: ModelNode, + modelGraph: ModelGraph, +): KeyValuePairs { let attrs: KeyValuePairs = {}; if (isOpNode(node)) { attrs = {...(node.attrs || {})}; @@ -735,6 +769,10 @@ export function getAttributesFromNode(node: ModelNode): KeyValuePairs { '#descendants': `${(node.descendantsNodeIds || []).length}`, '#children': `${(node.nsChildrenIds || []).length}`, }; + const customAttrs = + modelGraph.groupNodeAttributes?.[node.id.replace('___group___', '')] || + {}; + attrs = {...attrs, ...customAttrs}; } return attrs; } @@ -749,7 +787,7 @@ export function getAttrValueRangeMatchesForNode( ): SearchMatch[] { const matches: SearchMatch[] = []; - const attrs = getAttributesFromNode(node); + const attrs = getAttributesFromNode(node, modelGraph); const value = attrs[attrName]; if (value != null) { const numValue = Number(value); diff --git a/src/ui/src/components/visualizer/expandable_info_text.ng.html b/src/ui/src/components/visualizer/expandable_info_text.ng.html index 19aa834d..aca69fcd 100644 --- a/src/ui/src/components/visualizer/expandable_info_text.ng.html +++ b/src/ui/src/components/visualizer/expandable_info_text.ng.html @@ -21,6 +21,7 @@ [style.color]="textColor" [class.has-bg-color]="hasBgColor" [class.has-overflow]="hasOverflow" + [class.has-multiple-lines]="hasMultipleLines" [class.expanded]="expanded" (click)="handleToggleExpand($event, true)">
@@ -62,7 +63,7 @@ } @else if (type === 'quantization') {
{{formatQuantization}}
} @else { - {{text}} +
{{text}}
}
diff --git a/src/ui/src/components/visualizer/expandable_info_text.scss b/src/ui/src/components/visualizer/expandable_info_text.scss index 73d25e61..89eb484d 100644 --- a/src/ui/src/components/visualizer/expandable_info_text.scss +++ b/src/ui/src/components/visualizer/expandable_info_text.scss @@ -26,7 +26,8 @@ align-items: flex-start; position: relative; - &.has-overflow { + &.has-overflow, + &.has-multiple-lines { cursor: pointer; .icon-container { @@ -73,6 +74,12 @@ overflow: auto; } + .text-content { + white-space: pre-wrap; + max-height: 500px; + overflow: auto; + } + .namespace-content { display: flex; flex-direction: column; diff --git a/src/ui/src/components/visualizer/expandable_info_text.ts b/src/ui/src/components/visualizer/expandable_info_text.ts index cc42fd51..526d3844 100644 --- a/src/ui/src/components/visualizer/expandable_info_text.ts +++ b/src/ui/src/components/visualizer/expandable_info_text.ts @@ -93,7 +93,7 @@ export class ExpandableInfoText implements AfterViewInit, OnDestroy, OnChanges { } handleToggleExpand(event: MouseEvent, fromExpandedText = false) { - if (!this.hasOverflow) { + if (!this.hasOverflow && !this.hasMultipleLines) { return; } @@ -116,6 +116,10 @@ export class ExpandableInfoText implements AfterViewInit, OnDestroy, OnChanges { return this.hasOverflowInternal; } + get hasMultipleLines(): boolean { + return this.type !== 'namespace' && this.text.includes('\n'); + } + get iconName(): string { return this.expanded ? 'unfold_less' : 'unfold_more'; } diff --git a/src/ui/src/components/visualizer/info_panel.ts b/src/ui/src/components/visualizer/info_panel.ts index 633f4a77..0184da76 100644 --- a/src/ui/src/components/visualizer/info_panel.ts +++ b/src/ui/src/components/visualizer/info_panel.ts @@ -73,6 +73,7 @@ enum SectionLabel { GRAPH_INFO = 'Graph info', NODE_INFO = 'Node info', LAYER_INFO = 'Layer info', + LAYER_ATTRS = 'Layer attributes', ATTRIBUTES = 'Attributes', NODE_DATA_PROVIDERS = 'Node data providers', IDENTICAL_GROUPS = 'Identical groups', @@ -410,7 +411,9 @@ export class InfoPanel { } getSectionToggleIcon(sectionName: string): string { - return this.isSectionCollapsed(sectionName) ? 'expand_more' : 'expand_less'; + return this.isSectionCollapsed(sectionName) + ? 'chevron_right' + : 'expand_more'; } handleLocateNode(nodeId: string, event: MouseEvent) { @@ -647,6 +650,18 @@ export class InfoPanel { value: String(layerCount), }, ); + + // Custom attributes. + const graphAttributes = this.curModelGraph.groupNodeAttributes?.['']; + if (graphAttributes) { + for (const key of Object.keys(graphAttributes)) { + graphSection.items.push({ + section: graphSection, + label: key, + value: graphAttributes[key], + }); + } + } } private genInfoDataForSelectedOpNode() { @@ -892,6 +907,27 @@ export class InfoPanel { showOnNode: this.curShowOnGroupNodeInfoIds.has(label), }); + // Section for custom attributes. + const groupAttributes = + this.curModelGraph.groupNodeAttributes?.[ + groupNode.id.replace('___group___', '') + ]; + if (groupAttributes) { + const attrsSection: InfoSection = { + label: SectionLabel.LAYER_ATTRS, + sectionType: 'group', + items: [], + }; + this.sections.push(attrsSection); + for (const key of Object.keys(groupAttributes)) { + attrsSection.items.push({ + section: nodeSection, + label: key, + value: groupAttributes[key], + }); + } + } + // Section for identical groups. if (groupNode.identicalGroupIndex != null) { this.identicalGroupNodes = this.curModelGraph.nodes.filter( diff --git a/src/ui/src/components/visualizer/view_on_node.ng.html b/src/ui/src/components/visualizer/view_on_node.ng.html index c14772cc..c6755051 100644 --- a/src/ui/src/components/visualizer/view_on_node.ng.html +++ b/src/ui/src/components/visualizer/view_on_node.ng.html @@ -53,9 +53,9 @@
0) { @@ -425,6 +428,19 @@ export function getNodeWidth( maxAttrLabelWidth = Math.max(maxAttrLabelWidth, attrLabelWidth); maxAttrValueWidth = Math.max(maxAttrValueWidth, attrValueWidth); } + + // Attrs. + if (showOnNodeItemTypes[ShowOnNodeItemType.LAYER_NODE_ATTRS]?.selected) { + const keyValuePairs = getGroupNodeAttrsKeyValuePairsForAttrsTable( + node, + modelGraph, + showOnNodeItemTypes[ShowOnNodeItemType.LAYER_NODE_ATTRS]?.filterRegex || + '', + ); + const widths = getMaxAttrLabelAndValueWidth(keyValuePairs); + maxAttrLabelWidth = Math.max(maxAttrLabelWidth, widths.maxAttrLabelWidth); + maxAttrValueWidth = Math.max(maxAttrValueWidth, widths.maxAttrValueWidth); + } } maxAttrValueWidth = Math.min( maxAttrValueWidth, @@ -467,7 +483,11 @@ export function getNodeHeight( nodeDataProviderRuns, ); } else if (isGroupNode(node)) { - attrsTableRowCount = getGroupNodeAttrsTableRowCount(showOnNodeItemTypes); + attrsTableRowCount = getGroupNodeAttrsTableRowCount( + node, + modelGraph, + showOnNodeItemTypes, + ); } return ( @@ -608,11 +628,24 @@ function getOpNodeAttrsTableRowCount( } function getGroupNodeAttrsTableRowCount( + node: GroupNode, + modelGraph: ModelGraph, showOnNodeItemTypes: Record, ): number { const baiscFieldIds = getGroupNodeFieldLabelsFromShowOnNodeItemTypes(showOnNodeItemTypes); - return baiscFieldIds.length; + + // Node attributes. + const attrsCount = showOnNodeItemTypes[ShowOnNodeItemType.LAYER_NODE_ATTRS] + ?.selected + ? getGroupNodeAttrsKeyValuePairsForAttrsTable( + node, + modelGraph, + showOnNodeItemTypes[ShowOnNodeItemType.LAYER_NODE_ATTRS]?.filterRegex || + '', + ).length + : 0; + return baiscFieldIds.length + attrsCount; } function addLayoutGraphEdge( diff --git a/src/ui/src/components/visualizer/worker/graph_processor.ts b/src/ui/src/components/visualizer/worker/graph_processor.ts index efcb2aa2..8179c858 100644 --- a/src/ui/src/components/visualizer/worker/graph_processor.ts +++ b/src/ui/src/components/visualizer/worker/graph_processor.ts @@ -678,7 +678,7 @@ export class GraphProcessor { } createEmptyModelGraph(): ModelGraph { - return { + const modelGraph: ModelGraph = { id: this.graph.id, collectionLabel: this.graph.collectionLabel || '', nodes: [], @@ -689,6 +689,11 @@ export class GraphProcessor { minDescendantOpNodeCount: -1, maxDescendantOpNodeCount: -1, }; + if (this.graph.groupNodeAttributes) { + modelGraph.groupNodeAttributes = this.graph.groupNodeAttributes; + } + + return modelGraph; } private getAncestorNamespaces(ns: string): string[] {