diff --git a/upstream/upstream-http/upstream.go b/upstream/upstream-http/upstream.go index 59704160..e3511f93 100644 --- a/upstream/upstream-http/upstream.go +++ b/upstream/upstream-http/upstream.go @@ -23,14 +23,14 @@ import ( //Http org type httpUpstream struct { - id string - name string - driver string - desc string - scheme string - balanceType string - app discovery.IApp - balanceFactory balance.IBalanceFactory + id string + name string + driver string + desc string + scheme string + balanceType string + app discovery.IApp + handler balance.IBalanceHandler } func (h *httpUpstream) Id() string { @@ -61,7 +61,11 @@ func (h *httpUpstream) Reset(conf interface{}, workers map[eosc.RequireId]interf if err != nil { return err } - h.balanceFactory, err = balance.GetFactory(h.balanceType) + f, err := balance.GetFactory(h.balanceType) + if err != nil { + return err + } + h.handler, err = f.Create(h.app) if err != nil { return err } @@ -82,14 +86,11 @@ func (h *httpUpstream) CheckSkill(skill string) bool { //send 请求发送,忽略重试 func (h *httpUpstream) Send(ctx *http_context.Context, serviceDetail service.IServiceDetail) (*http.Response, error) { - handler, err := h.balanceFactory.Create(h.app) - if err != nil { - return nil, err - } + var response *http.Response path := utils.TrimPrefixAll(ctx.ProxyRequest.TargetURL(), "/") - node, err := handler.Next() + node, err := h.handler.Next() if err != nil { return nil, err } @@ -103,7 +104,7 @@ func (h *httpUpstream) Send(ctx *http_context.Context, serviceDetail service.ISe node.Down() } h.app.NodeError(node.ID()) - node, err = handler.Next() + node, err = h.handler.Next() if err != nil { return nil, err } diff --git a/upstream/upstream-http/upstream_test.go b/upstream/upstream-http/upstream_test.go index 1b96dcc0..41c8d93b 100644 --- a/upstream/upstream-http/upstream_test.go +++ b/upstream/upstream-http/upstream_test.go @@ -55,6 +55,15 @@ func (s *Service) ProxyAddr() string { return s.ProxyAddr() } +func getWorker(factory eosc.IProfessionDriverFactory, cfg interface{}, profession string, name string, label string, desc string, params map[string]string, workerID, workerName string, worker map[eosc.RequireId]interface{}) (eosc.IWorker, error) { + driver, err := factory.Create(profession, name, label, desc, params) + if err != nil { + return nil, err + } + + return driver.Create(workerID, workerName, cfg, worker) +} + func TestSend(t *testing.T) { round_robin.Register() s := &Service{ @@ -66,60 +75,43 @@ func TestSend(t *testing.T) { } factory := NewFactory() t.Log("upstream extend info:", factory.ExtendInfo()) - driver, err := factory.Create("upstream", "http_proxy", "", "http转发驱动", nil) - if err != nil { - t.Error(err) - return - } - cfg := &Config{ + + staticConfig := &Config{ Name: "product-user", Driver: "http_proxy", Desc: "生产环境-用户模块", Scheme: "http", Type: "round-robin", - Config: "127.0.0.1:8580 weight=10;47.95.203.198:8080 weight=15", + Config: "127.0.0.1:8580 weight=10;47.95.203.198:8080 weight=30", Discovery: "static_1@discovery", } - staticDiscovery := discovery_static.NewFactory() - t.Log("static discovery extend info:", staticDiscovery.ExtendInfo()) - staticDriver, err := staticDiscovery.Create("discovery", "static", "", "静态服务发现驱动", nil) - if err != nil { - t.Error(err) - return - } - staticCfg := &discovery_static.Config{ + staticWorker, err := getWorker(discovery_static.NewFactory(), &discovery_static.Config{ Name: "static_1", Driver: "static", Labels: nil, Health: &discovery_static.HealthConfig{ Protocol: "http", - Method: "GET", - URL: "/", - SuccessCode: 404, + Method: "POST", + URL: "/Web/Test/params/print", + SuccessCode: 200, Period: 30, Timeout: 3000, }, HealthOn: true, + }, "discovery", "static", "", "静态服务发现", nil, "", "static_1", nil) + if err != nil { + t.Error(err) + return } - staticWorker, err := staticDriver.Create("", "static_1", staticCfg, nil) + allWorker := make(map[eosc.RequireId]interface{}) + allWorker["static_1@discovery"] = staticWorker + worker, err := getWorker(NewFactory(), staticConfig, "upstream", "http_proxy", "", "http转发驱动", nil, "", "product-user", allWorker) if err != nil { t.Error(err) return } - worker, err := driver.Create( - "", - "product-user", - cfg, - map[eosc.RequireId]interface{}{ - "static_1@discovery": staticWorker, - }) - if err != nil { - t.Error(err) - return - } - worker.Start() hUpstream, ok := worker.(upstream.IUpstream) if !ok { t.Error(ErrorStructType)