跳转至

mlflow

feishu.plugins.mlflow

MLflowBundle

无额外依赖的 MLflow 工具 bundle:注册 MLflow run / tracker 事件载荷归一化工具,以及 MLflow experiment / run 的只读查询工具(搜索实验、搜索 run、读取单个 run)。

源代码位于: feishu/plugins/mlflow.py
Python
class MLflowBundle:
    r"""Tool bundle exposing MLflow run normalization and read-only tracking queries.

    ``register`` adds four tools to the registry, in this order:
    ``normalize_mlflow_run``, ``search_mlflow_experiments``,
    ``search_mlflow_runs`` and ``get_mlflow_run``. The three query tools
    forward their arguments, together with the bundle ``context``, to the
    module-level ``_search_experiments`` / ``_search_runs`` / ``_get_run``
    helpers.
    """

    def register(self, registry: ToolRegistry, context: BundleContext) -> None:
        """Build each tool and add it to ``registry`` one at a time.

        Query handlers close over ``context``. Handler parameter names and
        defaults mirror the property names of each tool's input schema.
        """

        def normalize_handler(payload):
            return normalize_mlflow_run(payload)

        def experiments_handler(query=None, max_results=20):
            return _search_experiments(context, query=query, max_results=max_results)

        # ``filter`` shadows the builtin on purpose: it is the schema property name.
        def runs_handler(experiment_ids=None, experiment_name=None, filter=None, max_results=10, order_by=None):
            return _search_runs(
                context,
                experiment_ids=experiment_ids,
                experiment_name=experiment_name,
                filter=filter,
                max_results=max_results,
                order_by=order_by,
            )

        def run_handler(run_id):
            return _get_run(context, run_id=run_id)

        string_list = {"type": "array", "items": {"type": "string"}}

        tool_specs = [
            dict(
                name="normalize_mlflow_run",
                description="把 MLflow run 或 tracker 事件载荷归一化为紧凑 run 事实。",
                input_schema={
                    "type": "object",
                    "properties": {"payload": {"type": "object", "description": "MLflow run 或 tracker 事件 JSON。"}},
                    "required": ["payload"],
                    "additionalProperties": False,
                },
                handler=normalize_handler,
            ),
            dict(
                name="search_mlflow_experiments",
                description="搜索 MLflow experiments,用于定位训练任务所在实验。",
                input_schema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "实验名称关键词。"},
                        "max_results": {"type": "integer", "description": "最多返回多少个实验,默认 20。"},
                    },
                    "additionalProperties": False,
                },
                handler=experiments_handler,
            ),
            dict(
                name="search_mlflow_runs",
                description="搜索 MLflow runs,用于查看训练任务状态、指标、参数和最近运行。",
                input_schema={
                    "type": "object",
                    "properties": {
                        "experiment_ids": {**string_list, "description": "MLflow experiment_id 列表。"},
                        "experiment_name": {"type": "string", "description": "实验名称;未给 ids 时可用。"},
                        "filter": {"type": "string", "description": "MLflow runs/search filter 字符串。"},
                        "max_results": {"type": "integer", "description": "最多返回多少个 run,默认 10。"},
                        "order_by": {**string_list, "description": "MLflow order_by 列表。"},
                    },
                    "additionalProperties": False,
                },
                handler=runs_handler,
            ),
            dict(
                name="get_mlflow_run",
                description="读取单个 MLflow run 的状态、指标、参数、标签和 artifact URI。",
                input_schema={
                    "type": "object",
                    "properties": {"run_id": {"type": "string", "description": "MLflow run_id。"}},
                    "required": ["run_id"],
                    "additionalProperties": False,
                },
                handler=run_handler,
            ),
        ]

        # Construct and register each tool before building the next one,
        # matching the original add-as-you-go sequence.
        for spec in tool_specs:
            registry.add(Tool(**spec))

MLflowClient

Small async MLflow Tracking REST client for read-only run inspection.

源代码位于: feishu/plugins/mlflow.py
Python
class MLflowClient:
    r"""Small async MLflow Tracking REST client for read-only run inspection.

    Every request opens a short-lived ``httpx.AsyncClient`` against
    ``base_url`` and calls ``raise_for_status``, so non-2xx responses surface
    as ``httpx.HTTPStatusError``.

    Authentication: when both ``username`` and ``api_token`` are set, HTTP
    Basic auth is sent. When only ``api_token`` is set, it is sent as a Bearer
    token. Otherwise no ``Authorization`` header is added.
    """

    # Largest page size / result count this client ever requests.
    _PAGE_LIMIT = 100
    # Upper bound on pages read while filtering experiments by name
    # client-side, so a very large tracking server cannot cause an unbounded scan.
    _MAX_EXPERIMENT_PAGES = 20

    def __init__(
        self,
        base_url: str,
        *,
        username: str | None = None,
        api_token: str | None = None,
        timeout: float = 10.0,
        transport: httpx.AsyncBaseTransport | None = None,
    ) -> None:
        """Store connection settings.

        Args:
            base_url: Tracking server root URL; a trailing ``/`` is stripped.
            username: Optional Basic-auth user (used only together with ``api_token``).
            api_token: Optional token for Basic (with ``username``) or Bearer auth.
            timeout: Per-request timeout in seconds passed to ``httpx``.
            transport: Optional custom ``httpx`` transport (e.g. for tests).
        """
        self.base_url = base_url.rstrip("/")
        self.username = username
        self.api_token = api_token
        self.timeout = timeout
        self.transport = transport

    async def search_experiments(self, *, query: str | None = None, max_results: int = 20) -> list[dict[str, Any]]:
        """Return compact experiment records, optionally filtered by name.

        Args:
            query: Case-insensitive substring matched against experiment names.
                The match is done client-side. Pages are therefore followed via
                ``next_page_token`` until ``max_results`` matches are collected,
                the server reports no further page, or ``_MAX_EXPERIMENT_PAGES``
                pages have been read.
            max_results: Maximum number of experiments returned, clamped to 1..100.

        Returns:
            Up to ``max_results`` experiments, requested with ``view_type=ALL``.
        """
        limit = max(1, min(int(max_results), self._PAGE_LIMIT))
        needle = query.lower() if query else None
        # Without a name filter, one page of ``limit`` items is the whole answer.
        # With a filter, read full pages to minimise round-trips.
        page_size = limit if needle is None else self._PAGE_LIMIT
        matches: list[dict[str, Any]] = []
        page_token: str | None = None
        for _ in range(self._MAX_EXPERIMENT_PAGES):
            body: dict[str, Any] = {"max_results": page_size, "view_type": "ALL"}
            if page_token:
                body["page_token"] = page_token
            payload = await self._request("POST", "/api/2.0/mlflow/experiments/search", json=body)
            payload_dict = payload if isinstance(payload, dict) else {}
            experiments_value = payload_dict.get("experiments")
            experiments = experiments_value if isinstance(experiments_value, list) else []
            page = [_compact_experiment(item) for item in experiments if isinstance(item, dict)]
            if needle is None:
                return page
            matches.extend(item for item in page if needle in str(item.get("name") or "").lower())
            if len(matches) >= limit:
                break
            next_token = payload_dict.get("next_page_token")
            if not isinstance(next_token, str) or not next_token:
                break
            page_token = next_token
        return matches[:limit]

    async def search_runs(
        self,
        *,
        experiment_ids: list[str] | None = None,
        experiment_name: str | None = None,
        filter: str | None = None,
        max_results: int = 10,
        order_by: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Search runs and return them normalized via ``normalize_mlflow_run``.

        Experiment scope is resolved in this order:

        1. ``experiment_ids`` if given.
        2. Otherwise, the experiment looked up by ``experiment_name``. If that
           name cannot be resolved to an id, ``[]`` is returned instead of
           widening the search.
        3. Otherwise, the first 100 experiments from ``search_experiments``.

        Args:
            experiment_ids: Explicit experiment ids to search.
            experiment_name: Experiment name, used only when no ids are given.
            filter: MLflow ``runs/search`` filter string, passed through.
            max_results: Maximum runs returned, clamped to 1..100.
            order_by: MLflow ``order_by`` clauses, passed through.
        """
        ids = list(experiment_ids or [])
        if not ids and experiment_name:
            experiment = await self._experiment_by_name(experiment_name)
            experiment_id = _text(experiment.get("experiment_id"))
            if not experiment_id:
                # A specific experiment was requested; searching every
                # experiment instead would return unrelated runs.
                return []
            ids = [experiment_id]
        if not ids:
            experiments = await self.search_experiments(max_results=self._PAGE_LIMIT)
            ids = [item["experiment_id"] for item in experiments if item.get("experiment_id")]
        if not ids:
            # Nothing to search; avoid a runs/search call with an empty id list.
            return []
        payload: dict[str, Any] = {
            "experiment_ids": ids,
            "max_results": max(1, min(int(max_results), self._PAGE_LIMIT)),
            "run_view_type": "ALL",
        }
        if filter:
            payload["filter"] = filter
        if order_by:
            payload["order_by"] = order_by
        result = await self._request("POST", "/api/2.0/mlflow/runs/search", json=payload)
        result_dict = result if isinstance(result, dict) else {}
        runs_value = result_dict.get("runs")
        runs = runs_value if isinstance(runs_value, list) else []
        return [normalize_mlflow_run(run) for run in runs if isinstance(run, dict)]

    async def get_run(self, run_id: str) -> dict[str, Any]:
        """Fetch one run by id and return it normalized via ``normalize_mlflow_run``."""
        payload = await self._request("GET", "/api/2.0/mlflow/runs/get", params={"run_id": run_id})
        # Guard against a non-object JSON body so normalization never sees a non-dict.
        return normalize_mlflow_run(payload if isinstance(payload, dict) else {})

    async def _experiment_by_name(self, name: str) -> dict[str, Any]:
        """Return the ``experiment`` object for ``name``, or ``{}`` if the body has none.

        NOTE(review): an unknown name presumably yields an error status from the
        server, which ``raise_for_status`` in ``_request`` turns into an
        exception. Confirm against the target MLflow version.
        """
        payload = await self._request(
            "GET",
            "/api/2.0/mlflow/experiments/get-by-name",
            params={"experiment_name": name},
        )
        experiment = payload.get("experiment") if isinstance(payload, dict) else {}
        return experiment if isinstance(experiment, dict) else {}

    async def _request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send one authenticated request and return the decoded JSON body.

        ``kwargs`` are forwarded to ``httpx.AsyncClient.request`` (e.g. ``json``,
        ``params``). Raises ``httpx.HTTPStatusError`` on non-2xx responses.
        """
        headers = {"Accept": "application/json"}
        if self.username and self.api_token:
            encoded = base64.b64encode(f"{self.username}:{self.api_token}".encode()).decode()
            headers["Authorization"] = f"Basic {encoded}"
        elif self.api_token:
            headers["Authorization"] = f"Bearer {self.api_token}"
        async with httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
            headers=headers,
            transport=self.transport,
        ) as client:
            response = await client.request(method, path, **kwargs)
            response.raise_for_status()
            return response.json()

normalize_mlflow_run

Python
normalize_mlflow_run(payload: dict[str, Any]) -> dict[str, Any]

把 MLflow run 或 tracker 事件载荷归一化为适合模型消费的 run 事实。

源代码位于: feishu/plugins/mlflow.py
Python
def normalize_mlflow_run(payload: dict[str, Any]) -> dict[str, Any]:
    r"""Normalize an MLflow run or tracker event payload into compact run facts.

    Three input layouts are handled:

    - A ``runs/get``-style envelope ``{"run": {...}}``.
    - A bare run with ``info`` / ``data`` sections.
    - A flat payload whose fields sit at the top level.

    When the ``run`` or ``info`` layer is missing, lookups fall back to the
    enclosing dict.

    Returns:
        A dict with keys ``run_id``, ``experiment_id``, ``status``, ``name``,
        ``artifact_uri``, ``metrics``, ``params`` and ``tags``. Scalar fields
        are passed through ``_text``; mapping fields through ``_key_value_dict``.
    """
    envelope = payload.get("run")
    if isinstance(envelope, dict):
        run = dict(envelope)
    else:
        run = payload

    raw_info = run.get("info")
    if isinstance(raw_info, dict):
        info = dict(raw_info)
    else:
        info = run

    data = _dict(run.get("data"))
    # Tags are resolved first because the run name may be stored in them.
    tags = _key_value_dict(data.get("tags") or run.get("tags"))

    run_id = _text(info.get("run_id") or info.get("run_uuid") or run.get("run_id"))
    experiment_id = _text(info.get("experiment_id") or run.get("experiment_id"))
    status = _text(info.get("status") or run.get("status"))
    run_name = _text(tags.get("mlflow.runName") or tags.get("run_name") or info.get("run_name"))
    artifact_uri = _text(info.get("artifact_uri") or run.get("artifact_uri"))
    metrics = _key_value_dict(data.get("metrics") or run.get("metrics"))
    params = _key_value_dict(data.get("params") or run.get("params"))

    return {
        "run_id": run_id,
        "experiment_id": experiment_id,
        "status": status,
        "name": run_name,
        "artifact_uri": artifact_uri,
        "metrics": metrics,
        "params": params,
        "tags": tags,
    }