summary refs log tree commit diff
diff options
context:
space:
mode:
authorTed Unangst <tedu@tedunangst.com>2022-03-25 23:45:03 -0400
committerTed Unangst <tedu@tedunangst.com>2022-03-25 23:45:03 -0400
commit8460b97bff37c577973d5c3f8e023746a67fc0f8 (patch)
tree8355675607689f0b0935f2f662278213689365b2
parent2163fff7274004d098815aafc9084b04215f1af7 (diff)
better filters
-rw-r--r--go.mod2
-rw-r--r--interpreter.go51
-rw-r--r--miniwebproxy.go7
3 files changed, 40 insertions, 20 deletions
diff --git a/go.mod b/go.mod
index c5268e1..4e5e6b3 100644
--- a/go.mod
+++ b/go.mod
@@ -7,3 +7,5 @@ require (
 	github.com/traefik/yaegi v0.11.2
 	golang.org/x/net v0.0.0-20190415214537-1da14a5a36f2
 )
+
+replace github.com/traefik/yaegi => ../yaegi
diff --git a/interpreter.go b/interpreter.go
index 825f3b3..8042d4f 100644
--- a/interpreter.go
+++ b/interpreter.go
@@ -5,11 +5,12 @@ import (
 	"log"
 	"net/http"
 	"os"
-	"runtime"
 	"reflect"
+	"runtime"
 
-	"github.com/traefik/yaegi/stdlib"
+	"github.com/andybalholm/cascadia"
 	"github.com/traefik/yaegi/interp"
+	"github.com/traefik/yaegi/stdlib"
 	"golang.org/x/net/html"
 )
 
@@ -17,7 +18,7 @@ type interpreter struct {
 	interp          *interp.Interpreter
 	newrequest      func(*http.Request)
 	shouldintercept func(string) bool
-	prefilter       func(*http.Request)
+	prefilter       func(*http.Request) *http.Response
 	filterresponse  func(io.Writer, *http.Request, *http.Response) bool
 	filterhtml      func(io.Writer, *http.Request, *html.Node) bool
 }
@@ -26,8 +27,8 @@ func (interp *interpreter) ShouldIntercept(hostname string) bool {
 	return interp.shouldintercept(hostname)
 }
 
-func (interp *interpreter) Prefilter(req *http.Request) {
-	interp.prefilter(req)
+func (interp *interpreter) Prefilter(req *http.Request) *http.Response {
+	return interp.prefilter(req)
 }
 
 func (interp *interpreter) FilterResponse(w io.Writer, req *http.Request, resp *http.Response) bool {
@@ -38,12 +39,11 @@ func (interp *interpreter) FilterHTML(w io.Writer, req *http.Request, root *html
 	return interp.filterhtml(w, req, root)
 }
 
-func defNewRequest(*http.Request) {
-}
 func defShouldIntercept(string) bool {
 	return true
 }
-func defPrefilter(*http.Request) {
+func defPrefilter(*http.Request) *http.Response {
+	return nil
 }
 func defFilterResponse(io.Writer, *http.Request, *http.Response) bool {
 	return false
@@ -55,28 +55,44 @@ func defFilterHTML(io.Writer, *http.Request, *html.Node) bool {
 func getinterpreter() *interpreter {
 	var err error
 	i := new(interpreter)
+	i.shouldintercept = defShouldIntercept
+	i.prefilter = defPrefilter
+	i.filterresponse = defFilterResponse
+	i.filterhtml = defFilterHTML
+	rawsrc, err := os.ReadFile("scripts/script.go")
+	if err != nil {
+		log.Printf("err load src: %s", err)
+		return i
+	}
+
 	i.interp = interp.New(interp.Options{GoPath: runtime.GOROOT()})
 	i.interp.Use(stdlib.Symbols)
 	exports := make(interp.Exports)
 	exports["html/html"] = map[string]reflect.Value{
-		"Node": reflect.ValueOf((*html.Node)(nil)),
+		"Node":        reflect.ValueOf((*html.Node)(nil)),
+		"Parse":       reflect.ValueOf(html.Parse),
+		"Render":      reflect.ValueOf(html.Render),
+		"ElementNode": reflect.ValueOf(html.ElementNode),
+		"TextNode":    reflect.ValueOf(html.TextNode),
+	}
+	exports["cascadia/cascadia"] = map[string]reflect.Value{
+		"MustCompile": reflect.ValueOf(cascadia.MustCompile),
+	}
+	exports["filttools/filttools"] = map[string]reflect.Value{
+		"Clean": reflect.ValueOf(clean),
 	}
 	i.interp.Use(exports)
 	i.interp.ImportUsed()
-	rawsrc, err := os.ReadFile("scripts/script.go")
-	if err != nil {
-		log.Panicf("err load src: %s", err)
-	}
 	src := string(rawsrc)
 	_, err = i.interp.Eval(src)
 	if err != nil {
-		log.Panicf("err eval src: %s", err)
+		log.Printf("err eval src: %s", err)
+		return i
 	}
 
 	si, err := i.interp.Eval("filt.ShouldIntercept")
 	if err != nil {
 		log.Printf("err eval shouldintercept: %s", err)
-		i.shouldintercept = defShouldIntercept
 	} else {
 		i.shouldintercept = si.Interface().(func(string) bool)
 	}
@@ -84,15 +100,13 @@ func getinterpreter() *interpreter {
 	pf, err := i.interp.Eval("filt.Prefilter")
 	if err != nil {
 		log.Printf("err eval prefilter: %s", err)
-		i.prefilter = defPrefilter
 	} else {
-		i.prefilter = pf.Interface().(func(*http.Request))
+		i.prefilter = pf.Interface().(func(*http.Request) *http.Response)
 	}
 
 	fr, err := i.interp.Eval("filt.FilterResponse")
 	if err != nil {
 		log.Printf("err eval filterresponse: %s", err)
-		i.filterresponse = defFilterResponse
 	} else {
 		i.filterresponse = fr.Interface().(func(io.Writer, *http.Request, *http.Response) bool)
 	}
@@ -100,7 +114,6 @@ func getinterpreter() *interpreter {
 	fh, err := i.interp.Eval("filt.FilterHTML")
 	if err != nil {
 		log.Printf("err eval filterhtml: %s", err)
-		i.filterhtml = defFilterHTML
 	} else {
 		i.filterhtml = fh.Interface().(func(io.Writer, *http.Request, *html.Node) bool)
 	}
diff --git a/miniwebproxy.go b/miniwebproxy.go
index 4adbd65..8bcf46a 100644
--- a/miniwebproxy.go
+++ b/miniwebproxy.go
@@ -323,7 +323,12 @@ func (pxr *Proxer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			clientreq.Header.Set("Accept-Encoding", "gzip")
 		}
 		log.Printf("clientreq url %s/%s\n", clientreq.URL.Hostname(), clientreq.URL.Path)
-		interp.Prefilter(clientreq)
+		filtresp := interp.Prefilter(clientreq)
+		if filtresp != nil {
+			log.Printf("filter responded fast")
+			filtresp.Write(clientconn)
+			return
+		}
 		clientreq.URL.Host = clientreq.URL.Hostname()
 
 		if clientreq.Header.Get("Upgrade") == "websocket" {