diff --git a/api.ts b/api.ts index bda7881..f7ec1f3 100644 --- a/api.ts +++ b/api.ts @@ -6,6 +6,7 @@ export class WSBackend { callback: (err: any, res: any) => void; timeout: ReturnType; }> = {}; + private broadcastListeners: Map void>> = new Map(); private counter = 0; private _api: any = null; private queue: Array<{ id: string; method: string; params: any[]; resolve: Function; reject: Function }> = []; @@ -13,7 +14,7 @@ export class WSBackend { private _connected = false; private connectedListeners: Array<() => void> = []; private disconnectedListeners: Array<() => void> = []; - private callTimeout = 10000; // 10 second timeout for calls + private callTimeout = 10000; constructor(url: string) { this.url = url; @@ -47,6 +48,20 @@ export class WSBackend { this.callTimeout = ms; } + public subscribe(topic: K, callback: (data: BroadcastEvents[K]) => void): () => void { + const topicStr = String(topic); + if (!this.broadcastListeners.has(topicStr)) { + this.broadcastListeners.set(topicStr, new Set()); + } + + const cb = callback as (data: any) => void; + this.broadcastListeners.get(topicStr)!.add(cb); + + return () => { + this.broadcastListeners.get(topicStr)?.delete(cb); + }; + } + private connect() { this.ws = new WebSocket(this.url); @@ -69,6 +84,14 @@ export class WSBackend { return; } + if (msg.topic) { + const listeners = this.broadcastListeners.get(msg.topic); + if (listeners) { + listeners.forEach(cb => cb(msg.data)); + } + return; + } + const callbackData = this.callbacks[msg.id]; if (!callbackData) return; @@ -91,9 +114,9 @@ export class WSBackend { Object.keys(this.callbacks).forEach(id => { const callbackData = this.callbacks[id]; if (callbackData) { - clearTimeout(callbackData.timeout); - callbackData.callback({ message: 'WebSocket disconnected' }, null); - delete this.callbacks[id]; + clearTimeout(callbackData.timeout); + callbackData.callback({ message: 'WebSocket disconnected' }, null); + delete this.callbacks[id]; } }); @@ -117,7 +140,7 @@ export class WSBackend { return new Promise((resolve) => { const timeout = setTimeout(() => { if (this.callbacks[id]) { - console.warn(`[WS] Call timeout for \${method} (id: \${id})`); + console.warn(`[WS] Call timeout for ${method} (id: ${id})`); this.callbacks[id].callback({ message: 'Request timeout' }, null); delete this.callbacks[id]; } diff --git a/generator.go b/generator.go index f552ac4..43222fc 100644 --- a/generator.go +++ b/generator.go @@ -15,9 +15,9 @@ var backendClient string var errorType = reflect.TypeOf((*error)(nil)).Elem() -func GenClient(api any, outPutPath string) error { +func GenClient(api any, broadcasts any, outPutPath string) error { - ts, err := GenerateTS(reflect.TypeOf(api), "RPCClient") + ts, err := GenerateTS(reflect.TypeOf(api), reflect.TypeOf(broadcasts), "RPCClient") if err != nil { return err } @@ -27,20 +27,39 @@ func GenClient(api any, outPutPath string) error { return nil } -func GenerateTS(apiType reflect.Type, clientName string) (string, error) { +func GenerateTS(apiType reflect.Type, broadcastType reflect.Type, clientName string) (string, error) { var b strings.Builder - - b.WriteString("// --- AUTO-GENERATED ---\n") - b.WriteString("// Generated by generator. Do not edit by hand (unless you know what you do).\n") - b.WriteString("// Types\n\n") - structs := map[string]reflect.Type{} + b.WriteString("export interface BroadcastEvents {\n") + + if broadcastType != nil { + bt := broadcastType + if bt.Kind() == reflect.Ptr { + bt = bt.Elem() + } + + if bt.Kind() == reflect.Struct { + for i := 0; i < bt.NumField(); i++ { + f := bt.Field(i) + topicName := f.Name + tag := f.Tag.Get("json") + if tag != "" { + topicName = strings.Split(tag, ",")[0] + } + collectStructs(f.Type, structs) + b.WriteString(fmt.Sprintf(" '%s': %s;\n", topicName, goTypeToTS(f.Type))) + } + } + } + b.WriteString("}\n\n") + + b.WriteString("export type BroadcastTopic = keyof BroadcastEvents;\n\n") + for method := range apiType.Methods() { i := -1 for param := range method.Type.Ins() { i++ - //skip self if i == 0 { continue } @@ -48,14 +67,11 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) { } outParams := method.Type.NumOut() - unsupportedMethod := "not supported, allowed layout are \nfunc() (error) \nfunc() (struct, error), \nfunc(x...) (struct, error)" - //func() + if outParams == 0 { return "", errors.New(method.Name + unsupportedMethod) } - - //func() (err) if outParams == 1 { if method.Type.Out(0).Implements(errorType) { collectStructs(method.Type.Out(0), structs) @@ -63,23 +79,18 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) { return "", errors.New(method.Name + unsupportedMethod) } } - - //func() (struct, err) if outParams > 1 && method.Type.Out(1).Implements(errorType) { collectStructs(method.Type.Out(0), structs) } - if outParams > 2 { return "", errors.New(method.Name + unsupportedMethod) } - } names := make([]string, 0, len(structs)) for n := range structs { names = append(names, n) } - slices.Sort(names) for _, name := range names { @@ -106,7 +117,6 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) { b.WriteString("}\n\n") } - b.WriteString("// Generic RPC method result\n") b.WriteString("export type RPCResult = { data: T; error?: any };\n\n") b.WriteString(fmt.Sprintf("export interface %s {\n", clientName)) @@ -141,7 +151,7 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) { if m.Type.NumOut() > 0 { resType = goTypeToTS(m.Type.Out(0)) if m.Type.Out(0).Implements(errorType) { - resType = "any" // It's just a func() error + resType = "any" } } @@ -168,7 +178,6 @@ func GenerateTS(apiType reflect.Type, clientName string) (string, error) { return b.String(), nil } - func collectStructs(t reflect.Type, m map[string]reflect.Type) { if t == nil { return diff --git a/generator_test.go b/generator_test.go index 73c1625..749207e 100644 --- a/generator_test.go +++ b/generator_test.go @@ -1,6 +1,7 @@ package rpc import ( + "fmt" "os" "reflect" "strings" @@ -29,6 +30,12 @@ func (s *MockAPI) GetUser(id int) (User, error) { return User{}, nil } func (s *MockAPI) CreateOrder(o Order) error { return nil } func (s *MockAPI) NoArgs() error { return nil } +type MockBroadcasts struct { + UserUpdated User `json:"user_updated"` + SystemAlert string `json:"system_alert"` + Tick float64 `json:"tick"` +} + func TestGoTypeToTS(t *testing.T) { tests := []struct { name string @@ -75,7 +82,7 @@ func TestCollectStructs(t *testing.T) { func TestGenerateTS_Success(t *testing.T) { api := &MockAPI{} - output, err := GenerateTS(reflect.TypeOf(api), "TestClient") + output, err := GenerateTS(reflect.TypeOf(api), nil, "TestClient") if err != nil { t.Fatalf("GenerateTS failed: %v", err) } @@ -89,6 +96,30 @@ func TestGenerateTS_Success(t *testing.T) { "CreateOrder(order: Order): Promise>;", "useBackend(url: string = '/ws')", } + fmt.Println(output) + + for _, s := range expectedStrings { + if !strings.Contains(output, s) { + t.Errorf("Generated TS missing expected string: %s", s) + } + } +} + +func TestGenerateTS_WithBroadcasts(t *testing.T) { + api := &MockAPI{} + broadcasts := &MockBroadcasts{} + output, err := GenerateTS(reflect.TypeOf(api), reflect.TypeOf(broadcasts), "TestClient") + if err != nil { + t.Fatalf("GenerateTS failed: %v", err) + } + + expectedStrings := []string{ + "export interface BroadcastEvents {", + "'user_updated': User;", + "'system_alert': string;", + "'tick': number;", + "export type BroadcastTopic = keyof BroadcastEvents;", + } for _, s := range expectedStrings { if !strings.Contains(output, s) { @@ -110,7 +141,7 @@ func TestGenerateTS_Validation(t *testing.T) { t.Run("InvalidReturnCount", func(t *testing.T) { api := &APIWithTooManyReturns{} - _, err := GenerateTS(reflect.TypeOf(api), "Client") + _, err := GenerateTS(reflect.TypeOf(api), nil, "Client") if err == nil { t.Error("Expected error for method with 3 return values") } @@ -118,7 +149,7 @@ func TestGenerateTS_Validation(t *testing.T) { t.Run("NoResidentError", func(t *testing.T) { api := &APIWithNoErrorHandler{} - _, err := GenerateTS(reflect.TypeOf(api), "Client") + _, err := GenerateTS(reflect.TypeOf(api), nil, "Client") if err == nil { t.Error("Expected error for method missing 'error' return type") } @@ -130,7 +161,23 @@ func TestGenClient(t *testing.T) { defer os.Remove(tmpFile) api := &MockAPI{} - err := GenClient(api, tmpFile) + err := GenClient(api, nil, tmpFile) + if err != nil { + t.Fatalf("GenClient failed: %v", err) + } + + if _, err := os.Stat(tmpFile); os.IsNotExist(err) { + t.Error("GenClient did not create the file") + } +} + +func TestGenClient_WithBroadcasts(t *testing.T) { + tmpFile := "test_client_broadcast.ts" + defer os.Remove(tmpFile) + + api := &MockAPI{} + broadcasts := &MockBroadcasts{} + err := GenClient(api, broadcasts, tmpFile) if err != nil { t.Fatalf("GenClient failed: %v", err) }