commit 085535931fae51ea752d8777b2037870af100f45 Author: kjuulh Date: Thu Jun 16 22:19:06 2022 +0200 Add base diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/ceen.go b/ceen.go new file mode 100644 index 0000000..55c2ab8 --- /dev/null +++ b/ceen.go @@ -0,0 +1,105 @@ +package ceen + +import ( + "fmt" + "github.com/kjuulh/ceen/codec" + "github.com/kjuulh/ceen/id" + "github.com/kjuulh/ceen/types" + "github.com/nats-io/nats.go" +) + +type Ceen struct { + nc *nats.Conn + js nats.JetStreamContext + types *types.Registry + id id.ID +} + +type ceenOption func(o *Ceen) error + +func (f ceenOption) addOption(o *Ceen) error { + return f(o) +} + +type CeenOption interface { + addOption(o *Ceen) error +} + +func TypeRegistry(types *types.Registry) CeenOption { + return ceenOption(func(o *Ceen) error { + o.types = types + return nil + }) +} + +func (c *Ceen) EventStore(name string) (*EventStore, error) { + return &EventStore{name: name, c: c}, nil +} + +func (c *Ceen) UnpackEvent(msg *nats.Msg) (*Event, error) { + eventType := msg.Header.Get(eventTypeHdr) + codecName := msg.Header.Get(eventCodecHdr) + var ( + data any + err error + ) + + codc, ok := codec.Codecs[codecName] + if !ok { + return nil, fmt.Errorf("%w: %s", codec.ErrCodecNotRegistered, codecName) + } + + if c.types == nil { + var b []byte + err = codc.Unmarshal(msg.Data, &b) + data = b + } else { + var v any + v, err = c.types.Init(eventType) + if err == nil { + err = codc.Unmarshal(msg.Data, v) + data = v + } + } + if err != nil { + return nil, err + } + + var seq uint64 + if msg.Reply != "" { + md, err := msg.Metadata() + if err != nil { + return nil, fmt.Errorf("unpack: failed to get metadata: %s", err) + } + seq = md.Sequence.Stream + } + + return &Event{ + ID: msg.Header.Get(nats.MsgIdHdr), + Type: msg.Header.Get(eventTypeHdr), + Data: data, + Subject: msg.Subject, + Sequence: seq, + }, nil +} + +func New(nc *nats.Conn, options ...CeenOption) (*Ceen, error) { + js, err := nc.JetStream() + if err != nil { + return nil, err + } + + c := &Ceen{ + nc: nc, + js: js, + id: id.NUID, + } + + for _, o := range options { + if err := o.addOption(c); err != nil { + return nil, err + } + } + + return c, nil +} diff --git a/codec/binary.go b/codec/binary.go new file mode 100644 index 0000000..da11aaf --- /dev/null +++ b/codec/binary.go @@ -0,0 +1,47 @@ +package codec + +import ( + "encoding" + "fmt" +) + +var ( + Binary Codec = &binaryCodec{} +) + +type binaryCodec struct{} + +func (*binaryCodec) Name() string { + return "binary" +} + +func (*binaryCodec) Marshal(v interface{}) ([]byte, error) { + // Check for native implementation. + if m, ok := v.(encoding.BinaryMarshaler); ok { + return m.MarshalBinary() + } + + // Otherwise assume byte slice. + b, ok := v.([]byte) + if !ok { + return nil, fmt.Errorf("value not []byte") + } + + return b, nil +} + +func (*binaryCodec) Unmarshal(b []byte, v interface{}) error { + // Check for native implementation. + if u, ok := v.(encoding.BinaryUnmarshaler); ok { + return u.UnmarshalBinary(b) + } + + // Otherwise assume byte slice. + bp, ok := v.(*[]byte) + if !ok { + return fmt.Errorf("value must be *[]byte") + } + + *bp = append((*bp)[:0], b...) + return nil +} diff --git a/codec/codec.go b/codec/codec.go new file mode 100644 index 0000000..acc640e --- /dev/null +++ b/codec/codec.go @@ -0,0 +1,20 @@ +package codec + +import "errors" + +var ( + ErrCodecNotRegistered = errors.New("ceen: codec not registered") + + Default = JSON + + Codecs = map[string]Codec{ + JSON.Name(): JSON, + Binary.Name(): Binary, + } +) + +type Codec interface { + Name() string + Marshal(any) ([]byte, error) + Unmarshal([]byte, any) error +} diff --git a/codec/json.go b/codec/json.go new file mode 100644 index 0000000..5cf060f --- /dev/null +++ b/codec/json.go @@ -0,0 +1,24 @@ +package codec + +import "encoding/json" + +var ( + JSON Codec = &jsonCodec{} +) + +type jsonCodec struct{} + +func (*jsonCodec) Name() string { + return "json" +} + +func (*jsonCodec) Marshal(v any) ([]byte, error) { + return json.Marshal(v) + +} +func (*jsonCodec) Unmarshal(b []byte, v any) error { + if len(b) == 0 { + return nil + } + return json.Unmarshal(b, v) +} diff --git a/event.go b/event.go new file mode 100644 index 0000000..3ed9714 --- /dev/null +++ b/event.go @@ -0,0 +1,11 @@ +package ceen + +type Event struct { + ID string + + Type string + + Data any + Subject string + Sequence uint64 +} diff --git a/event_store.go b/event_store.go new file mode 100644 index 0000000..d495614 --- /dev/null +++ b/event_store.go @@ -0,0 +1,227 @@ +package ceen + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/kjuulh/ceen/codec" + "github.com/nats-io/nats.go" + "strings" +) + +const ( + eventTypeHdr = "ceen-type" + eventCodecHdr = "ceen-codec" +) + +type EventStore struct { + c *Ceen + name string +} + +func (es *EventStore) Create(conf *nats.StreamConfig) error { + if conf == nil { + conf = &nats.StreamConfig{} + } + conf.Name = es.name + + if len(conf.Subjects) == 0 { + conf.Subjects = []string{fmt.Sprintf("%s.>", es.name)} + } + _, err := es.c.js.AddStream(conf) + return err +} + +func (es *EventStore) Append(ctx context.Context, subject string, events ...*Event) (uint64, error) { + var ack *nats.PubAck + + for _, event := range events { + popts := []nats.PubOpt{ + nats.Context(ctx), + nats.ExpectStream(es.name), + } + + e, err := es.wrapEvent(event) + if err != nil { + return 0, err + } + + msg, err := es.packEvent(subject, e) + if err != nil { + return 0, err + } + + ack, err = es.c.js.PublishMsg(msg, popts...) + if err != nil { + if strings.Contains(err.Error(), "wrong last sequence") { + return 0, errors.New("wrong last sequence") + } + return 0, err + } + } + + return ack.Sequence, nil +} + +func (es *EventStore) wrapEvent(event *Event) (*Event, error) { + if event.Data == nil { + return nil, errors.New("event data is required") + } + + if es.c.types == nil { + if event.Type == "" { + return nil, errors.New("event type is required") + } + } else { + t, err := es.c.types.Lookup(event.Data) + if err != nil { + return nil, err + } + + if event.Type == "" { + event.Type = t + } else if event.Type != t { + return nil, fmt.Errorf("wrong type for event data: %s", event.Data) + } + } + if v, ok := event.Data.(validator); ok { + if err := v.Validate(); err != nil { + return nil, err + } + } + if event.ID == "" { + event.ID = es.c.id.New() + } + + return event, nil +} + +func (es *EventStore) packEvent(subject string, event *Event) (*nats.Msg, error) { + var ( + data []byte + err error + codecName string + ) + + if es.c.types == nil { + data, err = codec.Binary.Marshal(event.Data) + codecName = codec.Binary.Name() + } else { + data, err = es.c.types.Marshal(event.Data) + codecName = es.c.types.Codec().Name() + } + if err != nil { + return nil, err + } + + msg := nats.NewMsg(subject) + msg.Data = data + + msg.Header.Set(nats.MsgIdHdr, event.ID) + msg.Header.Set(eventTypeHdr, event.Type) + msg.Header.Set(eventCodecHdr, codecName) + + return msg, nil +} + +type loadOpts struct { + afterSeq *uint64 +} + +type natsApiError struct { + Code int `json:"code"` + ErrCode uint16 `json:"err_code"` + Description string `json:"description"` +} + +type natsStoredMsg struct { + Sequence uint64 `json:"seq"` +} + +type natsGetMsgRequest struct { + LastBySubject string `json:"last_by_subj"` +} + +type natsGetMsgResponse struct { + Type string `json:"type"` + Error *natsApiError `json:"error"` + Message *natsStoredMsg `json:"message"` +} + +func (es *EventStore) Load(ctx context.Context, subject string) ([]*Event, uint64, error) { + lastMsg, err := es.lastMsgForSubject(ctx, subject) + if err != nil { + return nil, 0, err + } + + if lastMsg.Sequence == 0 { + return nil, 0, nil + } + + sopts := []nats.SubOpt{ + nats.OrderedConsumer(), + nats.DeliverAll(), + } + + sub, err := es.c.js.SubscribeSync(subject, sopts...) + if err != nil { + return nil, 0, err + } + + defer sub.Unsubscribe() + + var events []*Event + + for { + msg, err := sub.NextMsgWithContext(ctx) + if err != nil { + return nil, 0, err + } + + event, err := es.c.UnpackEvent(msg) + if err != nil { + return nil, 0, err + } + + events = append(events, event) + if event.Sequence == lastMsg.Sequence { + break + } + } + + return events, lastMsg.Sequence, nil +} + +func (es *EventStore) lastMsgForSubject(ctx context.Context, subject string) (*natsStoredMsg, error) { + rsubject := fmt.Sprintf("$JS.API.STREAM.MSG.GET.%s", es.name) + + data, _ := json.Marshal(&natsGetMsgRequest{ + LastBySubject: subject, + }) + + msg, err := es.c.nc.RequestWithContext(ctx, rsubject, data) + if err != nil { + return nil, err + } + + var rep natsGetMsgResponse + err = json.Unmarshal(msg.Data, &rep) + if err != nil { + return nil, err + } + + if rep.Error != nil { + if rep.Error.Code == 404 { + return &natsStoredMsg{}, nil + } + + return nil, fmt.Errorf("%s (%d)", rep.Error.Description, rep.Error.Code) + } + + return rep.Message, nil +} + +func (es *EventStore) Delete() error { + return es.c.js.DeleteStream(es.name) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f101d34 --- /dev/null +++ b/go.mod @@ -0,0 +1,26 @@ +module github.com/kjuulh/ceen + +go 1.18 + +require ( + github.com/nats-io/nats-server/v2 v2.8.4 + github.com/nats-io/nats.go v1.16.0 +) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/klauspost/compress v1.14.4 // indirect + github.com/minio/highwayhash v1.0.2 // indirect + github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a // indirect + github.com/nats-io/nkeys v0.3.0 // indirect + github.com/nats-io/nuid v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.1.0 // indirect + github.com/stretchr/testify v1.7.2 // indirect + golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect + golang.org/x/sys v0.0.0-20220111092808-5a964db01320 // indirect + golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect + google.golang.org/protobuf v1.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..073845e --- /dev/null +++ b/go.sum @@ -0,0 +1,47 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/klauspost/compress v1.14.4 h1:eijASRJcobkVtSt81Olfh7JX43osYLwy5krOJo6YEu4= +github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= +github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a h1:lem6QCvxR0Y28gth9P+wV2K/zYUUAkJ+55U8cpS0p5I= +github.com/nats-io/jwt/v2 v2.2.1-0.20220330180145-442af02fd36a/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/nats-server/v2 v2.8.4 h1:0jQzze1T9mECg8YZEl8+WYUXb9JKluJfCBriPUtluB4= +github.com/nats-io/nats-server/v2 v2.8.4/go.mod h1:8zZa+Al3WsESfmgSs98Fi06dRWLH5Bnq90m5bKD/eT4= +github.com/nats-io/nats.go v1.16.0 h1:zvLE7fGBQYW6MWaFaRdsgm9qT39PJDQoju+DS8KsO1g= +github.com/nats-io/nats.go v1.16.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= +github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= +github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= +golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320 h1:0jf+tOCoZ3LyutmCOWpVni1chK4VfFLhRsDK7MhqGRY= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/id/id.go b/id/id.go new file mode 100644 index 0000000..10d4f70 --- /dev/null +++ b/id/id.go @@ -0,0 +1,17 @@ +package id + +import "github.com/nats-io/nuid" + +var ( + NUID ID = &nuidGen{} +) + +type ID interface { + New() string +} + +type nuidGen struct{} + +func (i *nuidGen) New() string { + return nuid.Next() +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..e3f2d49 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,27 @@ +package testutil + +import ( + "os" + + "github.com/nats-io/nats-server/v2/server" + natsserver "github.com/nats-io/nats-server/v2/test" +) + +func NewNatsServer(port int) *server.Server { + opts := natsserver.DefaultTestOptions + opts.Port = port + opts.JetStream = true + return natsserver.RunServer(&opts) +} + +func ShutdownNatsServer(s *server.Server) { + var sd string + if config := s.JetStreamConfig(); config != nil { + sd = config.StoreDir + } + s.Shutdown() + if sd != "" { + os.RemoveAll(sd) + } + s.WaitForShutdown() +} diff --git a/lib_test.go b/lib_test.go new file mode 100644 index 0000000..1aac1e3 --- /dev/null +++ b/lib_test.go @@ -0,0 +1,105 @@ +package ceen + +import ( + "context" + "fmt" + "github.com/kjuulh/ceen/internal/testutil" + "github.com/kjuulh/ceen/types" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/require" + "testing" +) + +func TestLib(t *testing.T) { + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, _ := nats.Connect(srv.ClientURL()) + + c, err := New(nc) + require.NoError(t, err) + + testItems, err := c.EventStore("test_items") + + err = testItems.Create(&nats.StreamConfig{ + Storage: nats.MemoryStorage, + }) + require.NoError(t, err) + + ctx := context.Background() + + seq, err := testItems.Append(ctx, "test_items.1", &Event{Type: "test_item", Data: []byte("first-item")}) + require.NoError(t, err) + require.Equal(t, uint64(1), seq) + + events, _, err := testItems.Load(ctx, "test_items.1") + require.NoError(t, err) + require.Equal(t, "test_item", events[0].Type) + require.Equal(t, any([]byte("first-item")), events[0].Data) +} + +type TestEventCreated struct { + ID string +} + +func TestLibWithRegistry(t *testing.T) { + tests := []struct { + Name string + Run func(t *testing.T, es *EventStore, subject string) + }{ + { + Name: "append-load-no-occ", + Run: func(t *testing.T, es *EventStore, subject string) { + ctx := context.Background() + testEvent := TestEventCreated{ID: "some-event-id"} + seq, err := es.Append(ctx, subject, &Event{Data: &testEvent}) + require.NoError(t, err) + require.Equal(t, uint64(1), seq) + + events, lseq, err := es.Load(ctx, subject) + require.NoError(t, err) + require.Equal(t, seq, lseq) + + require.True(t, events[0].ID != "") + require.Equal(t, "test-event-created", events[0].Type) + + data, ok := events[0].Data.(*TestEventCreated) + require.True(t, ok) + require.Equal(t, testEvent, *data) + }, + }, + } + + srv := testutil.NewNatsServer(-1) + defer testutil.ShutdownNatsServer(srv) + + nc, err := nats.Connect(srv.ClientURL()) + require.NoError(t, err) + + tr, err := types.NewRegistry(map[string]*types.Type{ + "test-event-created": { + Init: func() any { + return &TestEventCreated{} + }, + }, + }) + require.NoError(t, err) + + c, err := New(nc, TypeRegistry(tr)) + require.NoError(t, err) + + for i, test := range tests { + t.Run(test.Name, func(t *testing.T) { + es, err := c.EventStore("testevents") + require.NoError(t, err) + + _ = es.Delete() + err = es.Create(&nats.StreamConfig{Storage: nats.MemoryStorage}) + require.NoError(t, err) + + subject := fmt.Sprintf("testevents.%d", i) + + test.Run(t, es, subject) + }) + } +} diff --git a/types/registry.go b/types/registry.go new file mode 100644 index 0000000..dcc4642 --- /dev/null +++ b/types/registry.go @@ -0,0 +1,143 @@ +package types + +import ( + "errors" + "fmt" + "github.com/kjuulh/ceen/codec" + "reflect" + "regexp" +) + +var ( + ErrTypeNotValid = errors.New("ceen: type not valid") + ErrTypeNotRegistered = errors.New("ceen: type not registered") + ErrNoTypeForStruct = errors.New("ceen: no type for struct") + ErrMarshal = errors.New("ceen: marshal error") + ErrUnmarshal = errors.New("ceen: unmarshal error") + + nameRegex = regexp.MustCompile(`^[\w-]+(\.[\w-]+)*$`) +) + +type Type struct { + Init func() any +} + +type Registry struct { + rtypes map[reflect.Type]string + types map[string]*Type + codec codec.Codec +} + +func (r *Registry) Lookup(v any) (string, error) { + ref := reflect.TypeOf(v) + t, ok := r.rtypes[ref] + if !ok { + return "", fmt.Errorf("%w: %s", errors.New("no type for struct"), ref) + } + + return t, nil +} + +func (r *Registry) Init(eventType string) (any, error) { + t, ok := r.types[eventType] + if !ok { + return nil, fmt.Errorf("%w: %s", ErrTypeNotRegistered, eventType) + } + + v := t.Init() + + return v, nil +} + +func (r *Registry) validate(name string, typeDec *Type) error { + if name == "" { + return fmt.Errorf("%w: missing name", ErrTypeNotValid) + } + + if err := validateTypeName(name); err != nil { + return err + } + + if typeDec.Init == nil { + return fmt.Errorf("%w: %s", ErrTypeNotValid, name) + } + + v := typeDec.Init() + if v == nil { + return fmt.Errorf("%w: %s: init func returns nil", ErrTypeNotValid, name) + } + + rt := reflect.TypeOf(v) + + if rt.Kind() != reflect.Ptr { + return fmt.Errorf("%w: %s: init func must return a pointer value", ErrTypeNotValid, name) + } + + if rt.Elem().Kind() != reflect.Struct { + return fmt.Errorf("%w: %s", ErrTypeNotValid, name) + } + + b, err := r.codec.Marshal(v) + if err != nil { + return fmt.Errorf("%w: %s: failed to marshal with codec: %s", ErrTypeNotValid, name, err) + } + + err = r.codec.Unmarshal(b, v) + if err != nil { + return fmt.Errorf("%w: %s: failed to unmarshal with codec: %s", ErrTypeNotValid, name, err) + } + + return nil +} + +func validateTypeName(name string) error { + if !nameRegex.MatchString(name) { + return fmt.Errorf("%w: name %q has invalid characters", ErrTypeNotValid, name) + } + return nil +} + +func (r *Registry) addType(name string, typeDec *Type) { + r.types[name] = typeDec + + v := typeDec.Init() + rt := reflect.TypeOf(v) + + r.rtypes[rt] = name + r.rtypes[rt.Elem()] = name +} + +func (r *Registry) Marshal(data any) ([]byte, error) { + _, err := r.Lookup(data) + if err != nil { + return nil, err + } + + b, err := r.codec.Marshal(data) + if err != nil { + return b, fmt.Errorf("%T, marshal error: %w", data, err) + } + return b, nil +} + +func (r *Registry) Codec() codec.Codec { + return r.codec +} + +func NewRegistry(typeDecs map[string]*Type) (*Registry, error) { + r := &Registry{ + rtypes: make(map[reflect.Type]string), + types: make(map[string]*Type), + codec: codec.Default, + } + + for n, typeDec := range typeDecs { + err := r.validate(n, typeDec) + if err != nil { + return nil, err + } + r.addType(n, typeDec) + } + + return r, nil +} diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..24f3483 --- /dev/null +++ b/validator.go @@ -0,0 +1,5 @@ +package ceen + +type validator interface { + Validate() error +}