Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions golang/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/MetaGLM/glm-realtime-sdk/golang/events"
"github.com/MetaGLM/glm-realtime-sdk/golang/tools"
"github.com/gorilla/websocket"
)

Expand Down Expand Up @@ -113,6 +114,37 @@ func (r *realtimeClient) Send(event *events.Event) (err error) {
return err
}

// SendFrameByVideo splits the video payload carried by event into still
// frames and sends one websocket text message per frame.
//
// The event must have type events.RealtimeClientVideoAppend and a non-nil
// VideoFrame, otherwise an error is returned before anything is sent. The
// client's read lock is held for the whole call, including frame extraction.
//
// The passed event is modified: ClientTimestamp is filled in with the current
// time in milliseconds when it is not positive, and VideoFrame is overwritten
// with each extracted frame in turn before the event is serialized.
//
// The SPS/PPS values handed to tools.ExtractFramesToBase64 are hard-coded.
// Sending stops at the first write error, which is logged and returned.
func (r *realtimeClient) SendFrameByVideo(event *events.Event) (err error) {
	switch {
	case event.Type != events.RealtimeClientVideoAppend:
		return fmt.Errorf("event type is not RealtimeClientVideoAppend")
	case event.VideoFrame == nil:
		return fmt.Errorf("event videoFrame is nil")
	}

	r.lock.RLock()
	defer r.lock.RUnlock()

	if !r.isConnected {
		log.Printf("[RealtimeClient] Sending event fail, err: not connected\n")
		return fmt.Errorf("not connected")
	}

	// Stamp the event only when the caller did not supply a timestamp.
	if event.ClientTimestamp <= 0 {
		event.ClientTimestamp = time.Now().UnixMilli()
	}

	extracted, extractErr := tools.ExtractFramesToBase64(event.VideoFrame, "Z0LADJoFAAABMA==", "aM48gA==")
	if extractErr != nil {
		return fmt.Errorf("extract frames failed: %v", extractErr)
	}

	// The same event is reused for every frame; only VideoFrame changes.
	for _, frame := range extracted {
		event.VideoFrame = frame
		payload := []byte(event.ToJson())
		err = r.conn.WriteMessage(websocket.TextMessage, payload)
		if err != nil {
			log.Printf("[RealtimeClient] Send failed, error: %v\n", err)
			return err
		}
	}
	return nil
}

func (r *realtimeClient) readWsMsg() {
defer r.wg.Done()
deadline := time.Now().Add(waitTimeout)
Expand Down
1 change: 1 addition & 0 deletions golang/events/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const (
RealtimeClientEventSessionUpdate EventType = "session.update"
RealtimeClientEventTranscriptionSessionUpdate EventType = "transcription_session.update"
RealtimeClientEventInputAudioBufferAppend EventType = "input_audio_buffer.append"
RealtimeClientVideoAppend EventType = "input_audio_buffer.append_video_frame"
RealtimeClientEventInputAudioBufferCommit EventType = "input_audio_buffer.commit"
RealtimeClientEventInputAudioBufferClear EventType = "input_audio_buffer.clear"
RealtimeClientEventConversationItemCreate EventType = "conversation.item.create"
Expand Down
117 changes: 116 additions & 1 deletion golang/tools/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package tools

import (
	"bytes"
	"encoding/base64"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"sort"

"github.com/go-audio/audio"
"github.com/go-audio/wav"
Expand Down Expand Up @@ -120,4 +124,115 @@ func Pcm2Wav(pcmBytes []byte, sampleRate, numChannels, bitDepth int) ([]byte, er
copy(wavData[44:], pcmBytes)

return wavData, nil
}
}

// ExtractFramesToBase64 extracts still JPEG frames from an H.264 byte stream.
//
// data is the H.264 payload as raw bytes. spsB64 and ppsB64 are the
// base64-encoded SPS and PPS that InjectSPSPPS prepends to data before
// decoding. The resulting stream is written to a temporary directory and the
// external "ffmpeg" executable is run on it, sampling 2 frames per second
// into JPEG files. The temporary directory is removed before returning.
//
// Despite the function name, no base64 encoding happens here: each returned
// element holds the bytes of one JPEG file, ordered by file name.
//
// An error is returned when the temporary files cannot be created, the
// SPS/PPS cannot be decoded, ffmpeg fails, or a frame cannot be read back.
func ExtractFramesToBase64(data []byte, spsB64, ppsB64 string) ([][]byte, error) {
	var images [][]byte

	// Create a scratch directory for the input stream and the output frames.
	tempDir, err := os.MkdirTemp("", "video_process_*")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp dir: %w", err)
	}
	// Best-effort cleanup; a failure here is only logged.
	defer func() {
		if rmErr := os.RemoveAll(tempDir); rmErr != nil {
			log.Printf("failed to remove temp dir: %v", rmErr)
		}
	}()

	// 1. Prepend SPS/PPS and write the stream to an .h264 file.
	// Return the error instead of calling log.Fatal: exiting the process from
	// library code would also skip the deferred temp-dir cleanup.
	fixedData, err := InjectSPSPPS(data, spsB64, ppsB64)
	if err != nil {
		return nil, fmt.Errorf("inject SPS/PPS failed: %w", err)
	}

	h264Path := filepath.Join(tempDir, "input.h264")
	if err := os.WriteFile(h264Path, fixedData, 0644); err != nil {
		return nil, fmt.Errorf("write h264 file failed: %w", err)
	}

	// 2. Output file name pattern for the extracted frames.
	framePattern := filepath.Join(tempDir, "frame_%04d.jpg")

	// 3. Run ffmpeg to extract the frames.
	cmd := exec.Command(
		"ffmpeg",
		"-f", "h264",
		"-i", h264Path,
		"-vf", "fps=2", // 2 frames per second
		"-qscale:v", "2", // high-quality JPEG
		"-y", // allow overwriting
		framePattern,
	)

	// Forward ffmpeg's output to the process's own streams for debugging.
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	log.Printf("Running command: %v", cmd.Args)
	if err := cmd.Run(); err != nil {
		return nil, fmt.Errorf("ffmpeg execution failed: %w", err)
	}

	// 4. Collect every generated JPEG.
	matches, err := filepath.Glob(filepath.Join(tempDir, "frame_*.jpg"))
	if err != nil {
		return nil, fmt.Errorf("glob pattern error: %w", err)
	}

	// Sort by file name so frames are returned in sequence order.
	sortFiles(matches)

	for _, imgPath := range matches {
		imgData, err := os.ReadFile(imgPath)
		if err != nil {
			return nil, fmt.Errorf("read image file failed: %w", err)
		}
		images = append(images, imgData)
	}

	log.Printf("Successfully extracted %d frames.", len(images))
	return images, nil
}

// InjectSPSPPS returns a new byte slice laid out as
//
//	[start code][SPS][start code][PPS][rawH264]
//
// where the start code is the four bytes 00 00 00 01. b64SPS and b64PPS are
// the SPS and PPS in standard base64 encoding; rawH264 is appended unchanged
// and is not modified.
//
// An error wrapping the underlying base64 decoding error is returned when
// either b64SPS or b64PPS is not valid base64.
func InjectSPSPPS(rawH264 []byte, b64SPS, b64PPS string) ([]byte, error) {
	sps, err := base64.StdEncoding.DecodeString(b64SPS)
	if err != nil {
		return nil, fmt.Errorf("decode SPS failed: %w", err)
	}
	pps, err := base64.StdEncoding.DecodeString(b64PPS)
	if err != nil {
		return nil, fmt.Errorf("decode PPS failed: %w", err)
	}

	startCode := []byte{0x00, 0x00, 0x00, 0x01}

	// The final size is known up front, so allocate once.
	result := make([]byte, 0, 2*len(startCode)+len(sps)+len(pps)+len(rawH264))

	// SPS
	result = append(result, startCode...)
	result = append(result, sps...)

	// PPS
	result = append(result, startCode...)
	result = append(result, pps...)

	// Original payload, copied as-is.
	result = append(result, rawH264...)

	return result, nil
}

// sortFiles sorts file names in place in ascending lexicographic order, so
// zero-padded names such as frame_0001.jpg, frame_0002.jpg end up in sequence.
func sortFiles(files []string) {
	sort.Strings(files)
}
24 changes: 24 additions & 0 deletions golang/tools/tools_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package tools

import (
"encoding/base64"
"fmt"
"log"
"testing"
)

// TestExtractFramesToBase64 runs ExtractFramesToBase64 on a base64-encoded
// H.264 sample and reports the size of every extracted frame.
//
// The sample is left empty here, so the test skips itself until a real
// payload is pasted into video.
func TestExtractFramesToBase64(t *testing.T) {
	// Base64-encoded H.264 sample used as the test fixture.
	video := ""
	if video == "" {
		t.Skip("no sample video provided")
	}

	// Decode the fixture into raw bytes.
	data, err := base64.StdEncoding.DecodeString(video)
	if err != nil {
		// Fail only this test; log.Fatal would terminate the whole test binary.
		t.Fatalf("解码失败:%v", err)
	}

	frames, err := ExtractFramesToBase64(data, "Z0LADJoFAAABMA==", "aM48gA==")
	if err != nil {
		t.Fatalf("ExtractFramesToBase64 failed: %v", err)
	}

	log.Printf("extracted %d frames", len(frames))
	for i, frame := range frames {
		// Print sizes rather than dumping the image bytes.
		fmt.Printf("frame %d: %d bytes\n", i, len(frame))
	}
}