diff --git a/golang/client/client.go b/golang/client/client.go
index 3a5d93a..3b3a04e 100644
--- a/golang/client/client.go
+++ b/golang/client/client.go
@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"github.com/MetaGLM/glm-realtime-sdk/golang/events"
+	"github.com/MetaGLM/glm-realtime-sdk/golang/tools"
 	"github.com/gorilla/websocket"
 )
 
@@ -113,6 +114,63 @@ func (r *realtimeClient) Send(event *events.Event) (err error) {
 	return err
 }
 
+// SendFrameByVideo extracts JPEG frames from the H.264 payload carried in
+// event.VideoFrame and sends one RealtimeClientVideoAppend message per frame.
+//
+// It returns an error when the event is nil, has the wrong type or an empty
+// payload, when the client is not connected, or when extraction or a write fails.
+// event.ClientTimestamp is filled in with the current time when it is not set.
+// event.VideoFrame is restored to the caller's payload before returning.
+func (r *realtimeClient) SendFrameByVideo(event *events.Event) (err error) {
+	// Base64 SPS/PPS parameter sets prepended to the payload before decoding.
+	// NOTE(review): these look like one fixed stream configuration; confirm
+	// they match every encoder that feeds this client.
+	const (
+		spsBase64 = "Z0LADJoFAAABMA=="
+		ppsBase64 = "aM48gA=="
+	)
+
+	if event == nil {
+		return fmt.Errorf("event is nil")
+	}
+	if events.RealtimeClientVideoAppend != event.Type {
+		return fmt.Errorf("event type is not RealtimeClientVideoAppend")
+	}
+	if len(event.VideoFrame) == 0 {
+		return fmt.Errorf("event videoFrame is empty")
+	}
+
+	// NOTE(review): the read lock stays held while ffmpeg runs and while all
+	// frames are written; confirm this is acceptable for Close/reconnect paths.
+	r.lock.RLock()
+	defer r.lock.RUnlock()
+	if !r.isConnected {
+		log.Printf("[RealtimeClient] Sending event fail, err: not connected\n")
+		return fmt.Errorf("not connected")
+	}
+	if event.ClientTimestamp <= 0 {
+		event.ClientTimestamp = time.Now().UnixMilli()
+	}
+
+	// The event is reused as the envelope for every extracted frame, so keep
+	// the original payload and put it back when done.
+	payload := event.VideoFrame
+	defer func() { event.VideoFrame = payload }()
+
+	frames, err := tools.ExtractFramesToBase64(payload, spsBase64, ppsBase64)
+	if err != nil {
+		return fmt.Errorf("extract frames failed: %w", err)
+	}
+	for _, frame := range frames {
+		event.VideoFrame = frame
+		if err = r.conn.WriteMessage(websocket.TextMessage, []byte(event.ToJson())); err != nil {
+			log.Printf("[RealtimeClient] Send failed, error: %v\n", err)
+			return err
+		}
+	}
+	return nil
+}
+
 func (r *realtimeClient) readWsMsg() {
 	defer r.wg.Done()
 	deadline := time.Now().Add(waitTimeout)
diff --git a/golang/events/event.go b/golang/events/event.go
index 55f9b03..e2dd43c 100644
--- a/golang/events/event.go
+++ b/golang/events/event.go
@@ -11,6 +11,7 @@ const (
 	RealtimeClientEventSessionUpdate              EventType = "session.update"
 	RealtimeClientEventTranscriptionSessionUpdate EventType = "transcription_session.update"
 	RealtimeClientEventInputAudioBufferAppend     EventType = "input_audio_buffer.append"
+	RealtimeClientVideoAppend                     EventType = "input_audio_buffer.append_video_frame"
RealtimeClientEventInputAudioBufferCommit EventType = "input_audio_buffer.commit" RealtimeClientEventInputAudioBufferClear EventType = "input_audio_buffer.clear" RealtimeClientEventConversationItemCreate EventType = "conversation.item.create" diff --git a/golang/tools/tools.go b/golang/tools/tools.go index d3d1d02..e5f5524 100644 --- a/golang/tools/tools.go +++ b/golang/tools/tools.go @@ -2,10 +2,14 @@ package tools import ( "bytes" + "encoding/base64" "encoding/binary" "fmt" "io" + "log" "os" + "os/exec" + "path/filepath" "github.com/go-audio/audio" "github.com/go-audio/wav" @@ -120,4 +124,115 @@ func Pcm2Wav(pcmBytes []byte, sampleRate, numChannels, bitDepth int) ([]byte, er copy(wavData[44:], pcmBytes) return wavData, nil -} \ No newline at end of file +} + +// ExtractFramesToBase64 接收 base64 编码的 H.264 数据,返回抽帧后图片的 base64 数组 +func ExtractFramesToBase64(data []byte, spsB64, ppsB64 string) ([][]byte, error) { + var images [][]byte + // 创建临时目录 + tempDir, err := os.MkdirTemp("", "video_process_*") + if err != nil { + return nil, fmt.Errorf("failed to create temp dir: %v", err) + } + defer func(path string) { + err := os.RemoveAll(path) + if err != nil { + log.Printf("failed to remove temp dir: %v", err) + } + }(tempDir) // 自动清理 + + // 1. 解码 base64 到 .h264 文件 + h264Path := filepath.Join(tempDir, "input.h264") + // 注入 SPS/PPS + fixedData, err := InjectSPSPPS(data, spsB64, ppsB64) + if err != nil { + log.Fatal(err) + } + + if err := os.WriteFile(h264Path, fixedData, 0644); err != nil { + return nil, fmt.Errorf("write h264 file failed: %v", err) + } + + // 2. 设置输出帧路径 + framePattern := filepath.Join(tempDir, "frame_%04d.jpg") + + // 3. 
调用 ffmpeg 抽帧 + cmd := exec.Command( + "ffmpeg", + "-f", "h264", + "-i", h264Path, + "-vf", "fps=2", // 每秒 2 帧 + "-qscale:v", "2", // 高质量 JPEG + "-y", // 允许覆盖 + framePattern, + ) + + // 捕获输出用于调试(可选) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + log.Printf("Running command: %v", cmd.Args) + err = cmd.Run() + if err != nil { + return nil, fmt.Errorf("ffmpeg execution failed: %v", err) + } + + // 4. 查找所有生成的 jpg 文件并转为 base64 + matches, err := filepath.Glob(filepath.Join(tempDir, "frame_*.jpg")) + if err != nil { + return nil, fmt.Errorf("glob pattern error: %v", err) + } + + // 按文件名排序(保证顺序) + sortFiles(matches) + + for _, imgPath := range matches { + imgData, err := os.ReadFile(imgPath) // 替代 ioutil.ReadFile + if err != nil { + return nil, fmt.Errorf("read image file failed: %v", err) + } + images = append(images, imgData) + } + + log.Printf("Successfully extracted %d frames.", len(images)) + return images, nil +} + +func InjectSPSPPS(rawH264 []byte, b64SPS, b64PPS string) ([]byte, error) { + sps, err := base64.StdEncoding.DecodeString(b64SPS) + if err != nil { + return nil, fmt.Errorf("decode SPS failed: %v", err) + } + pps, err := base64.StdEncoding.DecodeString(b64PPS) + if err != nil { + return nil, fmt.Errorf("decode PPS failed: %v", err) + } + + // 构造完整数据:[start code][SPS][start code][PPS][原始数据] + var result []byte + + // 写入 SPS + result = append(result, 0x00, 0x00, 0x00, 0x01) + result = append(result, sps...) + + // 写入 PPS + result = append(result, 0x00, 0x00, 0x00, 0x01) + result = append(result, pps...) + + // 写入原始数据(即你现在拿到的 Type 1 流) + result = append(result, rawH264...) 
+ + return result, nil +} + +// sortFiles 简单排序文件名(如 frame_0001.jpg, frame_0002.jpg) +func sortFiles(files []string) { + // 使用标准库排序 + for i := 0; i < len(files); i++ { + for j := i + 1; j < len(files); j++ { + if files[i] > files[j] { + files[i], files[j] = files[j], files[i] + } + } + } +} diff --git a/golang/tools/tools_test.go b/golang/tools/tools_test.go new file mode 100644 index 0000000..a8b1155 --- /dev/null +++ b/golang/tools/tools_test.go @@ -0,0 +1,24 @@ +package tools + +import ( + "encoding/base64" + "fmt" + "log" + "testing" +) + +func TestExtractFramesToBase64(t *testing.T) { + video := "" + // 解码为 []byte + data, err := base64.StdEncoding.DecodeString(video) + if err != nil { + log.Fatal("解码失败:", err) + } + frames, err := ExtractFramesToBase64(data, "Z0LADJoFAAABMA==", "aM48gA==") + if err != nil { + panic(err) + } + for _, frame := range frames { + fmt.Println(frame) + } +}