diff --git a/composer/report/base.py b/composer/report/base.py index 9dfe647..717d0db 100644 --- a/composer/report/base.py +++ b/composer/report/base.py @@ -136,6 +136,9 @@ class Report(Template): def getGroupFields(self): return [f for f in self.fields if 'group' in f.executionSteps] + + def getTotalsFields(self): + return [f for f in self.fields if 'totals' in f.executionSteps] class BaseQueryCriteria(Component): diff --git a/composer/report/result.py b/composer/report/result.py index 2a43153..a8292fe 100644 --- a/composer/report/result.py +++ b/composer/report/result.py @@ -76,16 +76,29 @@ class GroupHeaderRow(BaseRow): if f.name == col.name: fields[idx] = col return fields + + +class TotalsRow(BaseRow): + + def getRawValue(self, attr): + return self.data.get(attr, u'') + + @Lazy + def displayedColumns(self): + return self.parent.context.getActiveOutputFields() class ResultSet(object): def __init__(self, context, data, rowFactory=Row, headerRowFactory=GroupHeaderRow, - sortCriteria=None, queryCriteria=BaseQueryCriteria()): + totalsRowFactory=TotalsRow, sortCriteria=None, + queryCriteria=BaseQueryCriteria()): + self.context = context # the report or report instance self.data = data self.rowFactory = rowFactory self.headerRowFactory = headerRowFactory + self.totalsRowFactory = totalsRowFactory self.sortCriteria = sortCriteria self.queryCriteria = queryCriteria self.totals = BaseRow(None, self) @@ -100,12 +113,23 @@ class ResultSet(object): headerColumn.__name__ = c.output headerRow.headerColumns.append(headerColumn) return headerRow - + + def getTotalsRow(self, result, columns): + totalsRow = self.totalsRowFactory(None, self) + for c in columns: + totalsRow.data[c.name] = 0.00 + for row in result: + for c in columns: + totalsRow.data[c.name] = totalsRow.data[c.name] + c.getRawValue(row) + return totalsRow + def getResult(self): result = [self.rowFactory(item, self) for item in self.data] result = [row for row in result if self.queryCriteria.check(row)] if self.sortCriteria: result.sort(key=lambda x: [f.getSortValue(x) for f in self.sortCriteria]) + if self.totalsColumns: + self.totals = self.getTotalsRow(result, self.totalsColumns) if self.groupColumns: res = [] groupValues = [None for f in self.groupColumns] @@ -132,3 +156,7 @@ class ResultSet(object): def groupColumns(self): return self.context.getGroupFields() + @Lazy + def totalsColumns(self): + return self.context.getTotalsFields() +