add broadcasts to api clie

nt
This commit is contained in:
2026-04-29 09:39:16 +02:00
parent 58219ab0eb
commit 1193f128bb
3 changed files with 109 additions and 30 deletions
+25 -2
View File
@@ -6,6 +6,7 @@ export class WSBackend {
callback: (err: any, res: any) => void;
timeout: ReturnType<typeof setTimeout>;
}> = {};
private broadcastListeners: Map<string, Set<(data: any) => void>> = new Map();
private counter = 0;
private _api: any = null;
private queue: Array<{ id: string; method: string; params: any[]; resolve: Function; reject: Function }> = [];
@@ -13,7 +14,7 @@ export class WSBackend {
private _connected = false;
private connectedListeners: Array<() => void> = [];
private disconnectedListeners: Array<() => void> = [];
private callTimeout = 10000; // 10 second timeout for calls
private callTimeout = 10000;
constructor(url: string) {
this.url = url;
@@ -47,6 +48,20 @@ export class WSBackend {
this.callTimeout = ms;
}
public subscribe<K extends keyof BroadcastEvents>(topic: K, callback: (data: BroadcastEvents[K]) => void): () => void {
const topicStr = String(topic);
if (!this.broadcastListeners.has(topicStr)) {
this.broadcastListeners.set(topicStr, new Set());
}
const cb = callback as (data: any) => void;
this.broadcastListeners.get(topicStr)!.add(cb);
return () => {
this.broadcastListeners.get(topicStr)?.delete(cb);
};
}
private connect() {
this.ws = new WebSocket(this.url);
@@ -69,6 +84,14 @@ export class WSBackend {
return;
}
if (msg.topic) {
const listeners = this.broadcastListeners.get(msg.topic);
if (listeners) {
listeners.forEach(cb => cb(msg.data));
}
return;
}
const callbackData = this.callbacks[msg.id];
if (!callbackData) return;
@@ -117,7 +140,7 @@ export class WSBackend {
return new Promise((resolve) => {
const timeout = setTimeout(() => {
if (this.callbacks[id]) {
console.warn(`[WS] Call timeout for \${method} (id: \${id})`);
console.warn(`[WS] Call timeout for ${method} (id: ${id})`);
this.callbacks[id].callback({ message: 'Request timeout' }, null);
delete this.callbacks[id];
}
+30 -21
View File
@@ -15,9 +15,9 @@ var backendClient string
var errorType = reflect.TypeOf((*error)(nil)).Elem()
func GenClient(api any, outPutPath string) error {
func GenClient(api any, broadcasts any, outPutPath string) error {
ts, err := GenerateTS(reflect.TypeOf(api), "RPCClient")
ts, err := GenerateTS(reflect.TypeOf(api), reflect.TypeOf(broadcasts), "RPCClient")
if err != nil {
return err
}
@@ -27,20 +27,39 @@ func GenClient(api any, outPutPath string) error {
return nil
}
func GenerateTS(apiType reflect.Type, clientName string) (string, error) {
func GenerateTS(apiType reflect.Type, broadcastType reflect.Type, clientName string) (string, error) {
var b strings.Builder
b.WriteString("// --- AUTO-GENERATED ---\n")
b.WriteString("// Generated by generator. Do not edit by hand (unless you know what you do).\n")
b.WriteString("// Types\n\n")
structs := map[string]reflect.Type{}
b.WriteString("export interface BroadcastEvents {\n")
if broadcastType != nil {
bt := broadcastType
if bt.Kind() == reflect.Ptr {
bt = bt.Elem()
}
if bt.Kind() == reflect.Struct {
for i := 0; i < bt.NumField(); i++ {
f := bt.Field(i)
topicName := f.Name
tag := f.Tag.Get("json")
if tag != "" {
topicName = strings.Split(tag, ",")[0]
}
collectStructs(f.Type, structs)
b.WriteString(fmt.Sprintf(" '%s': %s;\n", topicName, goTypeToTS(f.Type)))
}
}
}
b.WriteString("}\n\n")
b.WriteString("export type BroadcastTopic = keyof BroadcastEvents;\n\n")
for method := range apiType.Methods() {
i := -1
for param := range method.Type.Ins() {
i++
//skip self
if i == 0 {
continue
}
@@ -48,14 +67,11 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) {
}
outParams := method.Type.NumOut()
unsupportedMethod := "not supported, allowed layout are \nfunc() (error) \nfunc() (struct, error), \nfunc(x...) (struct, error)"
//func()
if outParams == 0 {
return "", errors.New(method.Name + unsupportedMethod)
}
//func() (err)
if outParams == 1 {
if method.Type.Out(0).Implements(errorType) {
collectStructs(method.Type.Out(0), structs)
@@ -63,23 +79,18 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) {
return "", errors.New(method.Name + unsupportedMethod)
}
}
//func() (struct, err)
if outParams > 1 && method.Type.Out(1).Implements(errorType) {
collectStructs(method.Type.Out(0), structs)
}
if outParams > 2 {
return "", errors.New(method.Name + unsupportedMethod)
}
}
names := make([]string, 0, len(structs))
for n := range structs {
names = append(names, n)
}
slices.Sort(names)
for _, name := range names {
@@ -106,7 +117,6 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) {
b.WriteString("}\n\n")
}
b.WriteString("// Generic RPC method result\n")
b.WriteString("export type RPCResult<T> = { data: T; error?: any };\n\n")
b.WriteString(fmt.Sprintf("export interface %s {\n", clientName))
@@ -141,7 +151,7 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) {
if m.Type.NumOut() > 0 {
resType = goTypeToTS(m.Type.Out(0))
if m.Type.Out(0).Implements(errorType) {
resType = "any" // It's just a func() error
resType = "any"
}
}
@@ -168,7 +178,6 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) {
return b.String(), nil
}
func collectStructs(t reflect.Type, m map[string]reflect.Type) {
if t == nil {
return
+51 -4
View File
@@ -1,6 +1,7 @@
package rpc
import (
"fmt"
"os"
"reflect"
"strings"
@@ -29,6 +30,12 @@ func (s *MockAPI) GetUser(id int) (User, error) { return User{}, nil }
func (s *MockAPI) CreateOrder(o Order) error { return nil }
func (s *MockAPI) NoArgs() error { return nil }
type MockBroadcasts struct {
UserUpdated User `json:"user_updated"`
SystemAlert string `json:"system_alert"`
Tick float64 `json:"tick"`
}
func TestGoTypeToTS(t *testing.T) {
tests := []struct {
name string
@@ -75,7 +82,7 @@ func TestCollectStructs(t *testing.T) {
func TestGenerateTS_Success(t *testing.T) {
api := &MockAPI{}
output, err := GenerateTS(reflect.TypeOf(api), "TestClient")
output, err := GenerateTS(reflect.TypeOf(api), nil, "TestClient")
if err != nil {
t.Fatalf("GenerateTS failed: %v", err)
}
@@ -89,6 +96,30 @@ func TestGenerateTS_Success(t *testing.T) {
"CreateOrder(order: Order): Promise<RPCResult<any>>;",
"useBackend(url: string = '/ws')",
}
fmt.Println(output)
for _, s := range expectedStrings {
if !strings.Contains(output, s) {
t.Errorf("Generated TS missing expected string: %s", s)
}
}
}
func TestGenerateTS_WithBroadcasts(t *testing.T) {
api := &MockAPI{}
broadcasts := &MockBroadcasts{}
output, err := GenerateTS(reflect.TypeOf(api), reflect.TypeOf(broadcasts), "TestClient")
if err != nil {
t.Fatalf("GenerateTS failed: %v", err)
}
expectedStrings := []string{
"export interface BroadcastEvents {",
"'user_updated': User;",
"'system_alert': string;",
"'tick': number;",
"export type BroadcastTopic = keyof BroadcastEvents;",
}
for _, s := range expectedStrings {
if !strings.Contains(output, s) {
@@ -110,7 +141,7 @@ func TestGenerateTS_Validation(t *testing.T) {
t.Run("InvalidReturnCount", func(t *testing.T) {
api := &APIWithTooManyReturns{}
_, err := GenerateTS(reflect.TypeOf(api), "Client")
_, err := GenerateTS(reflect.TypeOf(api), nil, "Client")
if err == nil {
t.Error("Expected error for method with 3 return values")
}
@@ -118,7 +149,7 @@ func TestGenerateTS_Validation(t *testing.T) {
t.Run("NoResidentError", func(t *testing.T) {
api := &APIWithNoErrorHandler{}
_, err := GenerateTS(reflect.TypeOf(api), "Client")
_, err := GenerateTS(reflect.TypeOf(api), nil, "Client")
if err == nil {
t.Error("Expected error for method missing 'error' return type")
}
@@ -130,7 +161,23 @@ func TestGenClient(t *testing.T) {
defer os.Remove(tmpFile)
api := &MockAPI{}
err := GenClient(api, tmpFile)
err := GenClient(api, nil, tmpFile)
if err != nil {
t.Fatalf("GenClient failed: %v", err)
}
if _, err := os.Stat(tmpFile); os.IsNotExist(err) {
t.Error("GenClient did not create the file")
}
}
func TestGenClient_WithBroadcasts(t *testing.T) {
tmpFile := "test_client_broadcast.ts"
defer os.Remove(tmpFile)
api := &MockAPI{}
broadcasts := &MockBroadcasts{}
err := GenClient(api, broadcasts, tmpFile)
if err != nil {
t.Fatalf("GenClient failed: %v", err)
}