118 lines
3.2 KiB
TypeScript
118 lines
3.2 KiB
TypeScript
import { useCallback, useEffect, useRef, useState } from "react"

import { root } from '@repo/core'
import { AigcController, SubmitTaskBody, GetTaskStatusResult } from "@repo/sdk"

import { ApiError } from "@/lib/types"
import { useError } from "../data/use-error"

/**
 * React hook for submitting an AIGC task through `AigcController` and tracking
 * its status, either on demand (`checkStatus`) or by polling (`startPolling`).
 *
 * @returns An object with:
 * - `loading` — true while `submitTask` is awaiting the controller.
 * - `error` — the last error recorded by `submitTask` or `checkStatus`.
 * - `taskId` — id returned by the most recent successful `submitTask`.
 * - `taskStatus` — the last status payload fetched by `checkStatus`.
 * - `isPolling` — true while a loop started by `startPolling` is active.
 * - `submitTask`, `checkStatus`, `startPolling`, `stopPolling` — actions whose
 *   identities are stable across renders.
 *
 * Any active polling loop is cleared automatically when the component unmounts.
 */
export const useAigcTask = () => {
  const [loading, setLoading] = useState(false)
  const [error, setError] = useState<ApiError | null>(null)
  const [taskId, setTaskId] = useState<string | null>(null)
  const [taskStatus, setTaskStatus] = useState<GetTaskStatusResult | null>(null)
  const [isPolling, setIsPolling] = useState(false)

  // Handle of the active polling interval, or null when no loop is running.
  // This ref (not `isPolling` state) is the source of truth for "is a loop running",
  // because it updates synchronously instead of on the next render.
  const pollingIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null)

  // Clear any running interval when the component unmounts so it does not keep
  // issuing status requests for a component that no longer exists.
  useEffect(() => {
    return () => {
      if (pollingIntervalRef.current) {
        clearInterval(pollingIntervalRef.current)
        pollingIntervalRef.current = null
      }
    }
  }, [])

  /**
   * Submits a new task and records its id.
   *
   * Clears `error`, `taskId` and `taskStatus` before sending the request.
   *
   * @param params Request body forwarded to `AigcController.submitTask`.
   * @returns `{ taskId, error }` — `taskId` is set on success; otherwise `error`
   *   describes the failure (and is also stored in `error` state).
   */
  const submitTask = useCallback(async (params: SubmitTaskBody) => {
    setLoading(true)
    setError(null)
    setTaskId(null)
    setTaskStatus(null)

    const aigc = root.get(AigcController)
    // NOTE(review): `useError` is invoked inside an async callback. If it uses React
    // hooks internally this breaks the Rules of Hooks — confirm it is a plain helper
    // that resolves to `{ data, error }`.
    const { data, error: submitError } = await useError(async () => await aigc.submitTask(params))
    setLoading(false)

    if (submitError) {
      setError(submitError)
      return { taskId: null, error: submitError }
    }

    if (data?.data) {
      setTaskId(data.data)
      return { taskId: data.data, error: null }
    }

    // The call succeeded but carried no task id: report it through `error` state
    // too, so the UI sees this failure the same way as the branch above.
    const missingIdError = { message: "提交任务失败" } as ApiError
    setError(missingIdError)
    return { taskId: null, error: missingIdError }
  }, [])

  /**
   * Fetches the current status of a task once and stores it in `taskStatus`.
   *
   * @param id Task id to query.
   * @returns `{ status, error }` — `status` is the payload on success; otherwise
   *   `error` describes the failure (and is also stored in `error` state).
   */
  const checkStatus = useCallback(async (id: string) => {
    const aigc = root.get(AigcController)
    const { data, error: statusError } = await useError(async () => await aigc.getTaskStatus(id))

    if (statusError) {
      setError(statusError)
      return { status: null, error: statusError }
    }

    if (data) {
      setTaskStatus(data)
      return { status: data, error: null }
    }

    const emptyStatusError = { message: "获取状态失败" } as ApiError
    setError(emptyStatusError)
    return { status: null, error: emptyStatusError }
  }, [])

  /**
   * Stops the active polling loop, if any, and resets `isPolling`.
   * Responses still in flight from the stopped loop are ignored by that loop.
   */
  const stopPolling = useCallback(() => {
    if (pollingIntervalRef.current) {
      clearInterval(pollingIntervalRef.current)
      pollingIntervalRef.current = null
    }
    setIsPolling(false)
  }, [])

  /**
   * Polls the task status immediately and then every 3 seconds until the task
   * completes, fails, a status request errors, or `stopPolling` is called.
   *
   * Does nothing if a polling loop is already running.
   *
   * @param id Task id to poll.
   * @param onComplete Called once with the final status when the task succeeds.
   * @param onError Called once when a status request fails or the task reports failure.
   */
  const startPolling = useCallback(
    (id: string, onComplete?: (result: GetTaskStatusResult) => void, onError?: (error: ApiError) => void) => {
      // Guard on the ref rather than `isPolling`: state only changes after a re-render,
      // so two quick calls would both pass a state check and leak an interval.
      if (pollingIntervalRef.current) return

      setIsPolling(true)
      setError(null)

      // True while a status request is outstanding; later ticks are skipped until it
      // returns, so slow responses cannot stack up overlapping requests.
      let inFlight = false

      const poll = async () => {
        if (inFlight) return
        inFlight = true
        try {
          const { status, error: pollError } = await checkStatus(id)

          // Drop responses that arrive after this loop was stopped or replaced,
          // so onComplete/onError fire at most once per loop.
          // (`intervalId` is assigned below, before any response can arrive.)
          if (pollingIntervalRef.current !== intervalId) return

          if (pollError) {
            stopPolling()
            onError?.(pollError)
            return
          }

          if (!status) return

          const isCompleted = status.status === true || status.status === "success"
          const isFailed = status.status === "failed" || status.status === false

          if (isCompleted) {
            stopPolling()
            onComplete?.(status)
          } else if (isFailed) {
            stopPolling()
            // `msg` is either a plain string or an object carrying `message`;
            // guard against it being absent instead of throwing inside the poll.
            const errorMsg =
              typeof status.msg === "string" ? status.msg : (status.msg?.message ?? "任务失败")
            onError?.({ message: errorMsg } as ApiError)
          }
          // Any other status means the task is still running: wait for the next tick.
        } finally {
          inFlight = false
        }
      }

      const intervalId = setInterval(() => void poll(), 3000)
      pollingIntervalRef.current = intervalId
      void poll()
    },
    [checkStatus, stopPolling]
  )

  return {
    loading,
    error,
    taskId,
    taskStatus,
    isPolling,
    submitTask,
    checkStatus,
    startPolling,
    stopPolling,
  }
}