因为在项目里使用了zgrab2,所以想学习它的设计,顺便学习go语言

zgrab2主程序

主程序main.go/cmd/zgrab2/main.go下,主要流程是:

  1. 解析命令行参数
  2. 根据参数,选择是多模块扫描还是单模块,将模块初始化s.Init(f),模块名称与扫描器绑定
  3. 初始化Monitor,用于统计信息
  4. zgrab2.Process(monitor)进行扫描
  5. 最后解析扫描统计信息输出

下面的代码省略了pprof、错误处理相关代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
func main() {
_, moduleType, flag, _ := zgrab2.ParseCommandLine(os.Args[1:]) // 解析命令行
if m, ok := flag.(*zgrab2.MultipleCommand); ok {
iniParser := zgrab2.NewIniParser() // 如果使用zgrab2 multiple -c mul.ini 命令
var modTypes []string
var flagsReturned []interface{}
if m.ConfigFileName == "-" { // 使用标准输入作为配置选项的输入
modTypes, flagsReturned, _ = iniParser.Parse(os.Stdin)
} else { // 使用配置文件(.ini)作为配置选项的输入
modTypes, flagsReturned, _ = iniParser.ParseFile(m.ConfigFileName)
}
// modTypes包含配置文件中需要启用的扫描模块,比如siemens、http等
for i, fl := range flagsReturned {
f, _ := fl.(zgrab2.ScanFlags)
mod := zgrab2.GetModule(modTypes[i])
// 每个模块对应的扫描器
s := mod.NewScanner()
s.Init(f)
zgrab2.RegisterScan(s.GetName(), s)
// 在变量scanner map[string] Scanner里注册
}
} else { // 使用单一模块, 如命令 zgrab2 siemens
mod := zgrab2.GetModule(moduleType)
s := mod.NewScanner()
s.Init(flag)
zgrab2.RegisterScan(moduleType, s)
}
monitor := zgrab2.MakeMonitor() // Monitor用于统计成功失败次数
monitor.Callback = func(_ string) {
dumpHeapProfile()
}
start := time.Now()
zgrab2.Process(monitor) // 进行扫描
end := time.Now()
s := Summary{
StatusesPerModule: monitor.GetStatuses(),
StartTime: start.Format(time.RFC3339),
EndTime: end.Format(time.RFC3339),
Duration: end.Sub(start).String(),
}
enc := json.NewEncoder(zgrab2.GetMetaFile())
if err := enc.Encode(&s); err != nil { // 输出结果
log.Fatalf("unable to write summary: %s", err.Error())
}
}

processing.go

/processing.go里的Process是主要的扫描函数,做了下面的几件事:

  1. 创建扫描任务通道processQueue、输出结果通道outputQueue
  2. 创建wokers,每个workers会初始化所有扫描模块,等待从processQueue获取扫描目标,之后调用grabTarget获取信息,将结果输出到outputQueue
  3. 解析输入,也就是那些ip地址,这个config.inputTargets会在config解析是自定义,默认是/input.go里面的func InputTargetsCSV(ch chan<- ScanTarget) error

关键过程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
//Start all the workers
for i := 0; i < workers; i++ {
go func(i int) {
for _, scannerName := range orderedScanners {
// 在orderedScanners里保存扫描模块名称, main()里注册
scanner := *scanners[scannerName]
scanner.InitPerSender(i) // 每个扫描模块的初始化,一般没有操作
}
for obj := range processQueue {
for run := uint(0); run < uint(config.ConnectionsPerHost); run++ {
result := grabTarget(obj, mon)
outputQueue <- result
}
}
workerDone.Done()
}(i)
}

grabTarget针对某一个IP(ScanTarget)遍历使用所有扫描模块,执行下面操作:

  1. 调用模块的GetTrigger,这个主要是有些模块会有执行顺序关系,比如A需要B先执行完后再执行
  2. RunScanner中调用当前模块的Scan方法,如果成功,传入成功到Monitor的结果状态通道中
  3. 之后将结果序列化后放入outputQueue
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func grabTarget(input ScanTarget, m *Monitor) []byte {
moduleResult := make(map[string]ScanResponse)

for _, scannerName := range orderedScanners {
// scannerName就是模块名称,比如modbus,siemens,crimson等
scanner := scanners[scannerName]
trigger := (*scanner).GetTrigger()
if input.Tag != trigger {
continue
}
defer func(name string) {
if e := recover(); e != nil {
log.Errorf("Panic on scanner %s when scanning target %s: %#v", scannerName, input.String(), e)
// Bubble out original error (with original stack) in lieu of explicitly logging the stack / error
panic(e)
}
}(scannerName)
name, res := RunScanner(*scanner, m, input)
moduleResult[name] = res
if res.Error != nil && !config.Multiple.ContinueOnError {
break
}
}

raw := Grab{IP: ipstr, Domain: input.Domain, Data: moduleResult}
result, err := json.Marshal(outputData) // 序列化结果
if err != nil {
log.Fatalf("unable to marshal data: %s", err)
}

return result

再往深就是每个扫描模块的Scan()的具体实现了。zgrab2提供了两个连接函数:

  1. func (target *ScanTarget) OpenUDP(flags *BaseFlags, udp *UDPFlags) (net.Conn, error)
  2. func (target *ScanTarget) Open(flags *BaseFlags) (net.Conn, error)
    主要功能是timeout、字符限制的设置

input.go

zgrab2自带一个GetTargetsCSV,用于从csv格式文件(也可以是stdin)获取输入,go自带encoding/csv/csvReader

  1. 每次csvreader.Read()获取一行数据
  2. 将数据解析为IP地址,如果是子网掩码下的地址,遍历所有可能的IP地址。
  3. 将IP地址放入processQueue
1
2
3
4
5
6
7
8
9
10
11
12
13
var ip net.IP
if ipnet != nil {
if ipnet.Mask != nil {
// expand CIDR block into one target for each IP
for ip = ipnet.IP.Mask(ipnet.Mask); ipnet.Contains(ip); incrementIP(ip) {
ch <- ScanTarget{IP: duplicateIP(ip), Domain: domain, Tag: tag}
}
continue
} else {
ip = ipnet.IP
}
}
ch <- ScanTarget{IP: ip, Domain: domain, Tag: tag}