Skip to content

Commit 106b5a7

Browse files
committed
Add custom filters for processing accelerators
1 parent c1bfc6e commit 106b5a7

File tree

6 files changed

+83
-13
lines changed

6 files changed

+83
-13
lines changed

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,16 @@ Each `ghw.AcceleratorDevice` struct contains the following fields:
966966
describing the processing accelerator card. This may be `nil` if no PCI device
967967
information could be determined for the card.
968968

969+
#### filters
970+
The `ghw.Accelerator()` function accepts a slice of filters, of type string, as parameter
971+
in format `[<vendor>]:[<device>][:<class>]`, (same is the _lspci_ command).
972+
973+
Some filter examples:
974+
* `::0302`. Select 3D controller cards.
975+
* `10de::0302`. Select Nvidia (`10de`) 3D controller cards (`0302`).
976+
* `1da3:1060:1200`. Select Habana Labs (`1da3`) Gaudi3 (`1060`) processing accelerator cards (`1200`).
977+
* `1002::`. Select AMD ATI hardware.
978+
969979
```go
970980
package main
971981

@@ -976,7 +986,11 @@ import (
976986
)
977987

978988
func main() {
979-
accel, err := ghw.Accelerator()
989+
filter := make([]string, 0)
990+
// example of a filter to detect 3D controllers
991+
// filter = append(filter, "::0302")
992+
993+
accel, err := ghw.Accelerator(filter)
980994
if err != nil {
981995
fmt.Printf("Error getting processing accelerator info: %v", err)
982996
}

cmd/ghwc/commands/accelerator.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ var acceleratorCmd = &cobra.Command{
2323

2424
// showAccelerator show processing accelerators information for the host system.
2525
func showAccelerator(cmd *cobra.Command, args []string) error {
26-
accel, err := ghw.Accelerator()
26+
filter := make([]string, 0)
27+
28+
accel, err := ghw.Accelerator(filter)
2729
if err != nil {
2830
return errors.Wrap(err, "error getting Accelerator info")
2931
}

host.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func Host(opts ...*WithOption) (*HostInfo, error) {
7373
if err != nil {
7474
return nil, err
7575
}
76-
acceleratorInfo, err := accelerator.New(opts...)
76+
acceleratorInfo, err := accelerator.New([]string{}, opts...)
7777
if err != nil {
7878
return nil, err
7979
}

pkg/accelerator/accelerator.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,19 @@ func (dev *AcceleratorDevice) String() string {
3737
}
3838

3939
type Info struct {
40-
ctx *context.Context
41-
Devices []*AcceleratorDevice `json:"devices"`
40+
ctx *context.Context
41+
Devices []*AcceleratorDevice `json:"devices"`
42+
DiscoveryFilters []string
4243
}
4344

4445
// New returns a pointer to an Info struct that contains information about the
4546
// accelerator devices on the host system
46-
func New(opts ...*option.Option) (*Info, error) {
47+
func New(filter []string, opts ...*option.Option) (*Info, error) {
4748
ctx := context.New(opts...)
48-
info := &Info{ctx: ctx}
49+
info := &Info{
50+
ctx: ctx,
51+
DiscoveryFilters: filter,
52+
}
4953

5054
if err := ctx.Do(info.load); err != nil {
5155
return nil, err

pkg/accelerator/accelerator_linux.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
package accelerator
77

88
import (
9-
"github.com/samber/lo"
9+
"fmt"
10+
"strings"
1011

1112
"github.com/jaypipes/ghw/pkg/context"
1213
"github.com/jaypipes/ghw/pkg/pci"
14+
"github.com/samber/lo"
1315
)
1416

1517
// PCI IDs list available at https://admin.pci-ids.ucw.cz/read/PD
@@ -60,13 +62,61 @@ func (i *Info) load() error {
6062
if !isAccelerator(device) {
6163
continue
6264
}
63-
accelDev := &AcceleratorDevice{
64-
Address: device.Address,
65-
PCIDevice: device,
65+
for _, filter := range i.DiscoveryFilters {
66+
if validate(filter, device) {
67+
accelDev := &AcceleratorDevice{
68+
Address: device.Address,
69+
PCIDevice: device,
70+
}
71+
accelDevices = append(accelDevices, accelDev)
72+
break
73+
}
6674
}
67-
accelDevices = append(accelDevices, accelDev)
6875
}
6976

7077
i.Devices = accelDevices
7178
return nil
7279
}
80+
81+
// validate checks if a given PCI device matches the provided filter string.
82+
//
83+
// The filter string is expected to be in the format "VendorID:ProductID:Class+Subclass".
84+
// Each part of the filter (VendorID, ProductID, Class+Subclass) is optional and can be
85+
// left empty, in which case the corresponding attribute is ignored during validation.
86+
//
87+
// Parameters:
88+
// - filter: A string in the form "VendorID:ProductID:Class+Subclass", where
89+
// any part of the string may be empty to represent a wildcard match.
90+
// - device: A pointer to a `pci.Device` structure.
91+
//
92+
// Returns:
93+
// - true: If the device matches the filter criteria (wildcards are supported).
94+
// - false: If the device does not match the filter criteria.
95+
//
96+
// Matching criteria:
97+
// - VendorID must match `device.Vendor.ID` if provided.
98+
// - ProductID must match `device.Product.ID` if provided.
99+
// - Class and Subclass must match the concatenated result of `device.Class.ID` and `device.Subclass.ID` if provided.
100+
//
101+
// Example:
102+
//
103+
// filter := "8086:1234:1200"
104+
// device := pci.Device{Vendor: Vendor{ID: "8086"}, Product: Product{ID: "1234"}, Class: Class{ID: "12"}, Subclass: Subclass{ID: "00"}}
105+
// isValid := validate(filter, &device) // returns true
106+
//
107+
// filter := "8086::1200" // Wildcard for ProductID
108+
// isValid := validate(filter, &device) // returns true
109+
//
110+
// filter := "::1200" // Wildcard for ProductID and VendorID
111+
// isValid := validate(filter, &device) // returns true
112+
func validate(filter string, device *pci.Device) bool {
113+
ids := strings.Split(filter, ":")
114+
115+
if (ids[0] == "" || ids[0] == device.Vendor.ID) &&
116+
(len(ids) < 2 || (ids[1] == "" || ids[1] == device.Product.ID)) &&
117+
(len(ids) < 3 || (ids[2] == "" || ids[2] == fmt.Sprintf("%s%s", device.Class.ID, device.Subclass.ID))) {
118+
return true
119+
}
120+
121+
return false
122+
}

pkg/accelerator/accelerator_linux_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func testScenario(t *testing.T, filename string, expectedDevs int) {
4141
_ = snapshot.Cleanup(tmpRoot)
4242
}()
4343

44-
info, err := accelerator.New(option.WithChroot(tmpRoot))
44+
info, err := accelerator.New([]string{}, option.WithChroot(tmpRoot))
4545
if err != nil {
4646
t.Fatalf("Expected nil err, but got %v", err)
4747
}

0 commit comments

Comments
 (0)