diff --git a/pub/fed.go b/pub/fed.go index 8de8f0b..41298f5 100644 --- a/pub/fed.go +++ b/pub/fed.go @@ -14,11 +14,10 @@ import ( ) var ( - // TODO: Just respond with an HTTP error, don't put onus on Application. - // ErrObjectRequired means the activity needs its object property set. - ErrObjectRequired = errors.New("object property required") - // ErrTargetRequired means the activity needs its target property set. - ErrTargetRequired = errors.New("target property required") + // errObjectRequired means the activity needs its object property set. + errObjectRequired = errors.New("object property required") + // errTargetRequired means the activity needs its target property set. + errTargetRequired = errors.New("target property required") ) // TODO: Helper for sending arbitrary ActivityPub objects. @@ -33,8 +32,6 @@ type Pubber interface { // application specific behavior. The handler will return true if it handled // the request as an ActivityPub request. If it returns an error, it is up to // the client to determine how to respond via HTTP. - // - // Note that the error could be ErrObjectRequired or ErrTargetRequired. PostOutbox(c context.Context, w http.ResponseWriter, r *http.Request) (bool, error) GetOutbox(c context.Context, w http.ResponseWriter, r *http.Request) (bool, error) } @@ -186,6 +183,10 @@ func (f *federator) PostInbox(c context.Context, w http.ResponseWriter, r *http. return true, err } if err = f.getPostInboxResolver(c, *r.URL).Deserialize(m); err != nil { + if err == errObjectRequired || err == errTargetRequired { + w.WriteHeader(http.StatusBadRequest) + return true, nil + } return true, err } if err := f.addToInbox(c, r, m); err != nil { @@ -301,6 +302,10 @@ func (f *federator) PostOutbox(c context.Context, w http.ResponseWriter, r *http } deliverable := false if err = f.getPostOutboxResolver(c, m, &deliverable, &m).Deserialize(m); err != nil { + if err == errObjectRequired || err == errTargetRequired { + w.WriteHeader(http.StatusBadRequest) + return true, nil + } return true, err } if err := f.addToOutbox(c, r, m); err != nil { @@ -369,7 +374,7 @@ func (f *federator) handleClientCreate(ctx context.Context, deliverable *bool, t return func(s *streams.Create) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } c := s.Raw() // When a Create activity is posted, the actor of the activity @@ -472,7 +477,7 @@ func (f *federator) handleClientUpdate(c context.Context, rawJson map[string]int return func(s *streams.Update) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } // Update should partially replace the 'object' with only the // changed top-level fields. @@ -522,7 +527,7 @@ func (f *federator) handleClientDelete(c context.Context, deliverable *bool) fun return func(s *streams.Delete) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } ids, err := getObjectIds(s.Raw()) if err != nil { @@ -552,7 +557,7 @@ func (f *federator) handleClientFollow(c context.Context, deliverable *bool) fun return func(s *streams.Follow) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } return f.ClientCallbacker.Follow(c, s) } @@ -576,9 +581,9 @@ func (f *federator) handleClientAdd(c context.Context, deliverable *bool) func(s return func(s *streams.Add) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } else if s.LenTarget() == 0 { - return ErrTargetRequired + return errTargetRequired } raw := s.Raw() ids, err := getTargetIds(raw) @@ -640,9 +645,9 @@ func (f *federator) handleClientRemove(c context.Context, deliverable *bool) fun return func(s *streams.Remove) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } else if s.LenTarget() == 0 { - return ErrTargetRequired + return errTargetRequired } raw := s.Raw() ids, err := getTargetIds(raw) @@ -704,7 +709,7 @@ func (f *federator) handleClientLike(ctx context.Context, deliverable *bool) fun return func(s *streams.Like) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } getter := func(actor vocab.ObjectType, lc *vocab.CollectionType, loc *vocab.OrderedCollectionType) (bool, error) { if actor.IsLikedAnyURI() { @@ -739,7 +744,7 @@ func (f *federator) handleClientUndo(c context.Context, deliverable *bool) func( return func(s *streams.Undo) error { *deliverable = true if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } // TODO: Determine if we can support common forms of undo natively. return f.ClientCallbacker.Undo(c, s) @@ -750,7 +755,7 @@ func (f *federator) handleClientBlock(c context.Context, deliverable *bool) func return func(s *streams.Block) error { *deliverable = false if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } return f.ClientCallbacker.Block(c, s) } @@ -778,7 +783,7 @@ func (f *federator) handleCreate(c context.Context) func(s *streams.Create) erro // Create requires the client application to persist the 'object' that // was created. if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } raw := s.Raw() for i := 0; i < raw.ObjectLen(); i++ { @@ -798,7 +803,7 @@ func (f *federator) handleCreate(c context.Context) func(s *streams.Create) erro func (f *federator) handleUpdate(c context.Context) func(s *streams.Update) error { return func(s *streams.Update) error { if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } raw := s.Raw() if err := f.ensureActivityOriginMatchesObjects(raw); err != nil { @@ -821,7 +826,7 @@ func (f *federator) handleUpdate(c context.Context) func(s *streams.Update) erro func (f *federator) handleDelete(c context.Context) func(s *streams.Delete) error { return func(s *streams.Delete) error { if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } raw := s.Raw() if err := f.ensureActivityOriginMatchesObjects(raw); err != nil { @@ -856,7 +861,7 @@ func (f *federator) handleFollow(c context.Context, inboxURL url.URL) func(s *st // Permit either human-triggered or automatically triggering // 'Accept'/'Reject'. if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } todo := f.FederateApp.OnFollow(c, s) if todo != DoNothing { @@ -981,9 +986,9 @@ func (f *federator) handleAdd(c context.Context) func(s *streams.Add) error { // Add is client application specific, generally involving adding an // 'object' to a specific 'target' collection. if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } else if s.LenTarget() == 0 { - return ErrTargetRequired + return errTargetRequired } raw := s.Raw() ids, err := getTargetIds(raw) @@ -1045,9 +1050,9 @@ func (f *federator) handleRemove(c context.Context) func(s *streams.Remove) erro // Remove is client application specific, generally involving removing // an 'object' from a specific 'target' collection. if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } else if s.LenTarget() == 0 { - return ErrTargetRequired + return errTargetRequired } raw := s.Raw() ids, err := getTargetIds(raw) @@ -1107,7 +1112,7 @@ func (f *federator) handleRemove(c context.Context) func(s *streams.Remove) erro func (f *federator) handleLike(c context.Context) func(s *streams.Like) error { return func(s *streams.Like) error { if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } getter := func(object vocab.ObjectType, lc *vocab.CollectionType, loc *vocab.OrderedCollectionType) (bool, error) { if object.IsLikesAnyURI() { @@ -1145,7 +1150,7 @@ func (f *federator) handleUndo(c context.Context) func(s *streams.Undo) error { // Here we enforce that the actors on the Undo must correspond // to all objects' original actors in some manner. if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } raw := s.Raw() if err := f.ensureActivityActorsMatchObjectActors(raw); err != nil { @@ -1161,7 +1166,7 @@ func (f *federator) handleBlock(c context.Context) func(s *streams.Block) error // were to accept it. return func(s *streams.Block) error { if s.LenObject() == 0 { - return ErrObjectRequired + return errObjectRequired } return nil } diff --git a/pub/fed_test.go b/pub/fed_test.go index 9151dee..f59655a 100644 --- a/pub/fed_test.go +++ b/pub/fed_test.go @@ -1945,8 +1945,10 @@ func TestPostInbox_RequiresObject(t *testing.T) { resp := httptest.NewRecorder() req := ActivityPubRequest(httptest.NewRequest("POST", testInboxURI, bytes.NewBuffer(MustSerialize(test.input())))) handled, err := p.PostInbox(context.Background(), resp, req) - if err != ErrObjectRequired { - t.Fatalf("(%s) expected %s, got %s", test.name, ErrObjectRequired, err) + if resp.Code != http.StatusBadRequest { + t.Fatalf("(%s) expected %d, got %d", test.name, http.StatusBadRequest, resp.Code) + } else if err != nil { + t.Fatal(err) } else if !handled { t.Fatalf("(%s) expected handled, got !handled", test.name) } @@ -1990,8 +1992,10 @@ func TestPostInbox_RequiresTarget(t *testing.T) { resp := httptest.NewRecorder() req := ActivityPubRequest(httptest.NewRequest("POST", testInboxURI, bytes.NewBuffer(MustSerialize(test.input())))) handled, err := p.PostInbox(context.Background(), resp, req) - if err != ErrTargetRequired { - t.Fatalf("(%s) expected %s, got %s", test.name, ErrTargetRequired, err) + if resp.Code != http.StatusBadRequest { + t.Fatalf("(%s) expected %d, got %d", test.name, http.StatusBadRequest, resp.Code) + } else if err != nil { + t.Fatal(err) } else if !handled { t.Fatalf("(%s) expected handled, got !handled", test.name) } @@ -4327,8 +4331,10 @@ func TestPostOutbox_RequiresObject(t *testing.T) { resp := httptest.NewRecorder() req := Sign(ActivityPubRequest(httptest.NewRequest("POST", testOutboxURI, bytes.NewBuffer(MustSerialize(test.input()))))) handled, err := p.PostOutbox(context.Background(), resp, req) - if err != ErrObjectRequired { - t.Fatalf("(%s) expected %s, got %s", test.name, ErrObjectRequired, err) + if resp.Code != http.StatusBadRequest { + t.Fatalf("(%s) expected %d, got %d", test.name, http.StatusBadRequest, resp.Code) + } else if err != nil { + t.Fatal(err) } else if !handled { t.Fatalf("(%s) expected handled, got !handled", test.name) } @@ -4370,8 +4376,10 @@ func TestPostOutbox_RequiresTarget(t *testing.T) { resp := httptest.NewRecorder() req := Sign(ActivityPubRequest(httptest.NewRequest("POST", testOutboxURI, bytes.NewBuffer(MustSerialize(test.input()))))) handled, err := p.PostOutbox(context.Background(), resp, req) - if err != ErrTargetRequired { - t.Fatalf("(%s) expected %s, got %s", test.name, ErrTargetRequired, err) + if resp.Code != http.StatusBadRequest { + t.Fatalf("(%s) expected %d, got %d", test.name, http.StatusBadRequest, resp.Code) + } else if err != nil { + t.Fatal(err) } else if !handled { t.Fatalf("(%s) expected handled, got !handled", test.name) }